[Mlir-commits] [mlir] 003b28d - [mlir] Move affine's FoldMemRefAliasOps into its own pass (#172548)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Jan 2 10:13:46 PST 2026
Author: Krzysztof Drewniak
Date: 2026-01-02T10:13:42-08:00
New Revision: 003b28d0310dcb89e9c537a5d431c1d7b0d492f3
URL: https://github.com/llvm/llvm-project/commit/003b28d0310dcb89e9c537a5d431c1d7b0d492f3
DIFF: https://github.com/llvm/llvm-project/commit/003b28d0310dcb89e9c537a5d431c1d7b0d492f3.diff
LOG: [mlir] Move affine's FoldMemRefAliasOps into its own pass (#172548)
I'm planning to introduce an interface that'll allow FoldMemRefAliasOps
to not know about dialects like NVVM or GPU. To do this, however, I need
to get the `affine` ops (which need special handling in order to handle
their implicit affine maps) into a separate pass, analogously to how
`amdgpu` ops have these patterns under their dialect and not under
`memref`.
This commit also changes the expand/collapse_shape index resolvers to
return `void`, since they never actually failed and to make it clearer
that they modify IR.
(Note: An LLM did the initial refactoring and test movement, I've
reviewed the results and edited them some.)
Added:
mlir/lib/Dialect/Affine/Transforms/FoldMemRefAliasOps.cpp
mlir/test/Dialect/Affine/fold-memref-alias-ops.mlir
Modified:
mlir/include/mlir/Dialect/Affine/Passes.h
mlir/include/mlir/Dialect/Affine/Passes.td
mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp
mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt
mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Affine/Passes.h b/mlir/include/mlir/Dialect/Affine/Passes.h
index ec349ec48e33b..58596733a76c0 100644
--- a/mlir/include/mlir/Dialect/Affine/Passes.h
+++ b/mlir/include/mlir/Dialect/Affine/Passes.h
@@ -19,6 +19,7 @@
#include <limits>
namespace mlir {
+class RewritePatternSet;
namespace func {
class FuncOp;
@@ -126,6 +127,10 @@ std::unique_ptr<Pass> createAffineExpandIndexOpsPass();
/// operations.
std::unique_ptr<Pass> createAffineExpandIndexOpsAsAffinePass();
+/// Appends patterns for folding memref aliasing ops into affine load/store
+/// ops into `patterns`.
+void populateAffineFoldMemRefAliasOpPatterns(RewritePatternSet &patterns);
+
//===----------------------------------------------------------------------===//
// Registration
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Affine/Passes.td b/mlir/include/mlir/Dialect/Affine/Passes.td
index 1ad5344121d4e..430edffc29038 100644
--- a/mlir/include/mlir/Dialect/Affine/Passes.td
+++ b/mlir/include/mlir/Dialect/Affine/Passes.td
@@ -440,4 +440,15 @@ def AffineExpandIndexOpsAsAffine : Pass<"affine-expand-index-ops-as-affine"> {
let constructor = "mlir::affine::createAffineExpandIndexOpsAsAffinePass()";
}
+def AffineFoldMemRefAliasOps : Pass<"affine-fold-memref-alias-ops"> {
+ let summary = "Fold memref alias ops into affine memory ops";
+ let description = [{
+ The pass folds memref.subview, memref.expand_shape, and memref.collapse_shape
+ operations into affine memory operations (currently only `affine.load` and
+ `affine.store`). This is similar to the `fold-memref-alias-ops` pass in the
+ `memref` dialect but adds handling specific to affine operations.
+ }];
+ let dependentDialects = ["memref::MemRefDialect"];
+}
+
#endif // MLIR_DIALECT_AFFINE_PASSES
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
index c403386bd214a..d04ae101fe1dc 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
@@ -22,7 +22,7 @@ def FoldMemRefAliasOpsPass : Pass<"fold-memref-alias-ops"> {
from/to the original memref.
}];
let dependentDialects = [
- "affine::AffineDialect", "memref::MemRefDialect", "vector::VectorDialect"
+ "memref::MemRefDialect", "vector::VectorDialect"
];
}
@@ -197,7 +197,7 @@ def ReifyResultShapesPass : Pass<"reify-result-shapes"> {
let description = [{
This pass reifies the shapes of a subset of `ReifyRankedShapedTypeOpInterface`
ops with `tensor` results.
-
+
The pass currently only supports result shape type reification for:
- tensor::PadOp
- tensor::ConcatOp
diff --git a/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h b/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
index dd3b3dea6ef26..5d2429bb476e6 100644
--- a/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
+++ b/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
@@ -118,7 +118,7 @@ MemrefValue skipViewLikeOps(MemrefValue source);
/// Given the 'indices' of a load/store operation where the memref is a result
/// of a expand_shape op, returns the indices w.r.t to the source memref of the
-/// expand_shape op. For example
+/// expand_shape op into `sourceIndices`. For example
///
/// %0 = ... : memref<12x42xf32>
/// %1 = memref.expand_shape %0 [[0, 1], [2]]
@@ -129,14 +129,19 @@ MemrefValue skipViewLikeOps(MemrefValue source);
///
/// %2 = load %0[6 * i1 + i2, %i3] :
/// memref<12x42xf32>
-LogicalResult resolveSourceIndicesExpandShape(
- Location loc, PatternRewriter &rewriter,
- memref::ExpandShapeOp expandShapeOp, ValueRange indices,
- SmallVectorImpl<Value> &sourceIndices, bool startsInbounds);
+///
+/// If `startsInbounds` is true, optimizations that rely on all indices being
+/// non-negative and less than the corresponding memref dimension may be
+/// performed.
+void resolveSourceIndicesExpandShape(Location loc, PatternRewriter &rewriter,
+ memref::ExpandShapeOp expandShapeOp,
+ ValueRange indices,
+ SmallVectorImpl<Value> &sourceIndices,
+ bool startsInbounds);
/// Given the 'indices' of a load/store operation where the memref is a result
/// of a collapse_shape op, returns the indices w.r.t to the source memref of
-/// the collapse_shape op. For example
+/// the collapse_shape op, returning them in `sourceIndices`. For example
///
/// %0 = ... : memref<2x6x42xf32>
/// %1 = memref.collapse_shape %0 [[0, 1], [2]]
@@ -147,11 +152,10 @@ LogicalResult resolveSourceIndicesExpandShape(
///
/// %2 = load %0[%i1 / 6, %i1 % 6, %i2] :
/// memref<2x6x42xf32>
-LogicalResult
-resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter,
- memref::CollapseShapeOp collapseShapeOp,
- ValueRange indices,
- SmallVectorImpl<Value> &sourceIndices);
+void resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter,
+ memref::CollapseShapeOp collapseShapeOp,
+ ValueRange indices,
+ SmallVectorImpl<Value> &sourceIndices);
} // namespace memref
} // namespace mlir
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp
index d54751098410b..fb2b096df9c3d 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp
@@ -46,21 +46,15 @@ static LogicalResult foldMemrefViewOp(PatternRewriter &rewriter, Location loc,
return success();
})
.Case<memref::ExpandShapeOp>([&](memref::ExpandShapeOp expandShapeOp) {
- if (failed(mlir::memref::resolveSourceIndicesExpandShape(
- loc, rewriter, expandShapeOp, indices, resolvedIndices,
- false))) {
- return failure();
- }
+ mlir::memref::resolveSourceIndicesExpandShape(
+ loc, rewriter, expandShapeOp, indices, resolvedIndices, false);
memrefBase = expandShapeOp.getViewSource();
return success();
})
.Case<memref::CollapseShapeOp>(
[&](memref::CollapseShapeOp collapseShapeOp) {
- if (failed(mlir::memref::resolveSourceIndicesCollapseShape(
- loc, rewriter, collapseShapeOp, indices,
- resolvedIndices))) {
- return failure();
- }
+ mlir::memref::resolveSourceIndicesCollapseShape(
+ loc, rewriter, collapseShapeOp, indices, resolvedIndices);
memrefBase = collapseShapeOp.getViewSource();
return success();
})
diff --git a/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt
index c792200f4a49a..7bce124817032 100644
--- a/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt
@@ -7,6 +7,7 @@ add_mlir_dialect_library(MLIRAffineTransforms
AffineParallelize.cpp
AffineScalarReplacement.cpp
DecomposeAffineOps.cpp
+ FoldMemRefAliasOps.cpp
LoopCoalescing.cpp
LoopFusion.cpp
LoopTiling.cpp
@@ -34,6 +35,7 @@ add_mlir_dialect_library(MLIRAffineTransforms
MLIRArithDialect
MLIRIR
MLIRMemRefDialect
+ MLIRMemRefUtils
MLIRPass
MLIRSCFUtils
MLIRSideEffectInterfaces
diff --git a/mlir/lib/Dialect/Affine/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/Affine/Transforms/FoldMemRefAliasOps.cpp
new file mode 100644
index 0000000000000..febdb94970431
--- /dev/null
+++ b/mlir/lib/Dialect/Affine/Transforms/FoldMemRefAliasOps.cpp
@@ -0,0 +1,253 @@
+//===- FoldMemRefAliasOps.cpp - Fold memref alias ops for affine ops ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass contains affine-specific versions of the folding patterns for
+// memref.expand_shape, memref.collapse_shape, and memref.subview, since
+// those all need affine-specific handling that won't fit a general interface.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/Passes.h"
+#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+namespace affine {
+#define GEN_PASS_DEF_AFFINEFOLDMEMREFALIASOPS
+#include "mlir/Dialect/Affine/Passes.h.inc"
+} // namespace affine
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::affine;
+
+//===----------------------------------------------------------------------===//
+// Utility functions
+//===----------------------------------------------------------------------===//
+
+/// Given an AffineMap and a list of indices, apply the map to get the
+/// underlying indices (expanding the affine map).
+static void expandToUnderlyingIndices(AffineMap affineMap, ValueRange indices,
+ Location loc, PatternRewriter &rewriter,
+ SmallVectorImpl<Value> &result) {
+ SmallVector<OpFoldResult> indicesOfr(
+ llvm::map_to_vector(indices, [](Value v) -> OpFoldResult { return v; }));
+ for (unsigned i : llvm::seq(0u, affineMap.getNumResults())) {
+ OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
+ rewriter, loc, affineMap.getSubMap({i}), indicesOfr);
+ result.push_back(getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
+ }
+}
+
+//===----------------------------------------------------------------------===//
+// Patterns
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+struct AffineLoadOpOfSubViewOpFolder final : OpRewritePattern<AffineLoadOp> {
+ using Base::Base;
+
+ LogicalResult matchAndRewrite(AffineLoadOp loadOp,
+ PatternRewriter &rewriter) const override {
+ auto subViewOp = loadOp.getMemref().getDefiningOp<memref::SubViewOp>();
+
+ if (!subViewOp)
+ return rewriter.notifyMatchFailure(loadOp, "not a subview producer");
+
+ SmallVector<Value> indices;
+ expandToUnderlyingIndices(loadOp.getAffineMap(), loadOp.getIndices(),
+ loadOp.getLoc(), rewriter, indices);
+
+ SmallVector<Value> sourceIndices;
+ affine::resolveIndicesIntoOpWithOffsetsAndStrides(
+ rewriter, loadOp.getLoc(), subViewOp.getMixedOffsets(),
+ subViewOp.getMixedStrides(), subViewOp.getDroppedDims(), indices,
+ sourceIndices);
+
+ rewriter.replaceOpWithNewOp<AffineLoadOp>(loadOp, subViewOp.getSource(),
+ sourceIndices);
+ return success();
+ }
+};
+
+struct AffineLoadOpOfExpandShapeOpFolder final
+ : OpRewritePattern<AffineLoadOp> {
+ using Base::Base;
+
+ LogicalResult matchAndRewrite(AffineLoadOp loadOp,
+ PatternRewriter &rewriter) const override {
+ auto expandShapeOp =
+ loadOp.getMemref().getDefiningOp<memref::ExpandShapeOp>();
+
+ if (!expandShapeOp)
+ return failure();
+
+ SmallVector<Value> indices;
+ expandToUnderlyingIndices(loadOp.getAffineMap(), loadOp.getIndices(),
+ loadOp.getLoc(), rewriter, indices);
+
+ SmallVector<Value> sourceIndices;
+ // affine.load guarantees that indexes start inbounds, which impacts if our
+ // linearization is `disjoint`.
+ memref::resolveSourceIndicesExpandShape(
+ loadOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices,
+ /*startsInbounds=*/true);
+
+ rewriter.replaceOpWithNewOp<AffineLoadOp>(
+ loadOp, expandShapeOp.getViewSource(), sourceIndices);
+ return success();
+ }
+};
+
+struct AffineLoadOpOfCollapseShapeOpFolder final
+ : OpRewritePattern<AffineLoadOp> {
+ using Base::Base;
+
+ LogicalResult matchAndRewrite(AffineLoadOp loadOp,
+ PatternRewriter &rewriter) const override {
+ auto collapseShapeOp =
+ loadOp.getMemref().getDefiningOp<memref::CollapseShapeOp>();
+
+ if (!collapseShapeOp)
+ return failure();
+
+ SmallVector<Value> indices;
+ expandToUnderlyingIndices(loadOp.getAffineMap(), loadOp.getIndices(),
+ loadOp.getLoc(), rewriter, indices);
+
+ SmallVector<Value> sourceIndices;
+ memref::resolveSourceIndicesCollapseShape(
+ loadOp.getLoc(), rewriter, collapseShapeOp, indices, sourceIndices);
+
+ rewriter.replaceOpWithNewOp<AffineLoadOp>(
+ loadOp, collapseShapeOp.getViewSource(), sourceIndices);
+ return success();
+ }
+};
+
+struct AffineStoreOpOfSubViewOpFolder final : OpRewritePattern<AffineStoreOp> {
+ using Base::Base;
+
+ LogicalResult matchAndRewrite(AffineStoreOp storeOp,
+ PatternRewriter &rewriter) const override {
+ auto subViewOp = storeOp.getMemref().getDefiningOp<memref::SubViewOp>();
+
+ if (!subViewOp)
+ return rewriter.notifyMatchFailure(storeOp, "not a subview producer");
+
+ // For affine ops, we need to apply the map to get the "actual" indices.
+ SmallVector<Value> indices;
+ expandToUnderlyingIndices(storeOp.getAffineMap(), storeOp.getIndices(),
+ storeOp.getLoc(), rewriter, indices);
+
+ SmallVector<Value> sourceIndices;
+ affine::resolveIndicesIntoOpWithOffsetsAndStrides(
+ rewriter, storeOp.getLoc(), subViewOp.getMixedOffsets(),
+ subViewOp.getMixedStrides(), subViewOp.getDroppedDims(), indices,
+ sourceIndices);
+
+ rewriter.replaceOpWithNewOp<AffineStoreOp>(
+ storeOp, storeOp.getValue(), subViewOp.getSource(), sourceIndices);
+ return success();
+ }
+};
+
+struct AffineStoreOpOfExpandShapeOpFolder final
+ : OpRewritePattern<AffineStoreOp> {
+ using Base::Base;
+
+ LogicalResult matchAndRewrite(AffineStoreOp storeOp,
+ PatternRewriter &rewriter) const override {
+ auto expandShapeOp =
+ storeOp.getMemref().getDefiningOp<memref::ExpandShapeOp>();
+
+ if (!expandShapeOp)
+ return failure();
+
+ SmallVector<Value> indices;
+ expandToUnderlyingIndices(storeOp.getAffineMap(), storeOp.getIndices(),
+ storeOp.getLoc(), rewriter, indices);
+
+ SmallVector<Value> sourceIndices;
+ // affine.store guarantees that indexes start inbounds, which impacts if our
+ // linearization is `disjoint`.
+ memref::resolveSourceIndicesExpandShape(
+ storeOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices,
+ /*startsInbounds=*/true);
+
+ rewriter.replaceOpWithNewOp<AffineStoreOp>(
+ storeOp, storeOp.getValueToStore(), expandShapeOp.getViewSource(),
+ sourceIndices);
+ return success();
+ }
+};
+
+struct AffineStoreOpOfCollapseShapeOpFolder final
+ : OpRewritePattern<AffineStoreOp> {
+ using Base::Base;
+
+ LogicalResult matchAndRewrite(AffineStoreOp storeOp,
+ PatternRewriter &rewriter) const override {
+ auto collapseShapeOp =
+ storeOp.getMemref().getDefiningOp<memref::CollapseShapeOp>();
+
+ if (!collapseShapeOp)
+ return failure();
+
+ // For affine ops, we need to apply the map to get the "actual" indices.
+ SmallVector<Value> indices;
+ expandToUnderlyingIndices(storeOp.getAffineMap(), storeOp.getIndices(),
+ storeOp.getLoc(), rewriter, indices);
+
+ SmallVector<Value> sourceIndices;
+ memref::resolveSourceIndicesCollapseShape(
+ storeOp.getLoc(), rewriter, collapseShapeOp, indices, sourceIndices);
+
+ rewriter.replaceOpWithNewOp<AffineStoreOp>(
+ storeOp, storeOp.getValueToStore(), collapseShapeOp.getViewSource(),
+ sourceIndices);
+ return success();
+ }
+};
+
+} // namespace
+
+void affine::populateAffineFoldMemRefAliasOpPatterns(
+ RewritePatternSet &patterns) {
+ patterns
+ .add<AffineLoadOpOfSubViewOpFolder, AffineLoadOpOfExpandShapeOpFolder,
+ AffineLoadOpOfCollapseShapeOpFolder, AffineStoreOpOfSubViewOpFolder,
+ AffineStoreOpOfExpandShapeOpFolder,
+ AffineStoreOpOfCollapseShapeOpFolder>(patterns.getContext());
+}
+
+//===----------------------------------------------------------------------===//
+// Pass registration
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+struct AffineFoldMemRefAliasOpsPass final
+ : public affine::impl::AffineFoldMemRefAliasOpsBase<
+ AffineFoldMemRefAliasOpsPass> {
+ void runOnOperation() override;
+};
+
+} // namespace
+
+void AffineFoldMemRefAliasOpsPass::runOnOperation() {
+ RewritePatternSet patterns(&getContext());
+ affine::populateAffineFoldMemRefAliasOpPatterns(patterns);
+ (void)applyPatternsGreedily(getOperation(), std::move(patterns));
+}
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
index 3667fdb2bb728..3cacb7e29263b 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
@@ -11,7 +11,6 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
@@ -198,22 +197,6 @@ class NVGPUAsyncCopyOpSubViewOpFolder final
};
} // namespace
-static SmallVector<Value>
-calculateExpandedAccessIndices(AffineMap affineMap,
- const SmallVector<Value> &indices, Location loc,
- PatternRewriter &rewriter) {
- SmallVector<OpFoldResult> indicesOfr(llvm::to_vector(
- llvm::map_range(indices, [](Value v) -> OpFoldResult { return v; })));
- SmallVector<Value> expandedIndices;
- for (unsigned i = 0, e = affineMap.getNumResults(); i < e; i++) {
- OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
- rewriter, loc, affineMap.getSubMap({i}), indicesOfr);
- expandedIndices.push_back(
- getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
- }
- return expandedIndices;
-}
-
template <typename XferOp>
static LogicalResult
preconditionsFoldSubViewOpImpl(RewriterBase &rewriter, XferOp xferOp,
@@ -262,28 +245,13 @@ LogicalResult LoadOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
if (failed(preconditionResult))
return preconditionResult;
- SmallVector<Value> indices(loadOp.getIndices().begin(),
- loadOp.getIndices().end());
- // For affine ops, we need to apply the map to get the operands to get the
- // "actual" indices.
- if (auto affineLoadOp =
- dyn_cast<affine::AffineLoadOp>(loadOp.getOperation())) {
- AffineMap affineMap = affineLoadOp.getAffineMap();
- auto expandedIndices = calculateExpandedAccessIndices(
- affineMap, indices, loadOp.getLoc(), rewriter);
- indices.assign(expandedIndices.begin(), expandedIndices.end());
- }
SmallVector<Value> sourceIndices;
affine::resolveIndicesIntoOpWithOffsetsAndStrides(
rewriter, loadOp.getLoc(), subViewOp.getMixedOffsets(),
- subViewOp.getMixedStrides(), subViewOp.getDroppedDims(), indices,
- sourceIndices);
+ subViewOp.getMixedStrides(), subViewOp.getDroppedDims(),
+ loadOp.getIndices(), sourceIndices);
llvm::TypeSwitch<Operation *, void>(loadOp)
- .Case([&](affine::AffineLoadOp op) {
- rewriter.replaceOpWithNewOp<affine::AffineLoadOp>(
- loadOp, subViewOp.getSource(), sourceIndices);
- })
.Case([&](memref::LoadOp op) {
rewriter.replaceOpWithNewOp<memref::LoadOp>(
loadOp, subViewOp.getSource(), sourceIndices, op.getNontemporal());
@@ -328,32 +296,14 @@ LogicalResult LoadOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
if (!expandShapeOp)
return failure();
- SmallVector<Value> indices(loadOp.getIndices().begin(),
- loadOp.getIndices().end());
- // For affine ops, we need to apply the map to get the operands to get the
- // "actual" indices.
- if (auto affineLoadOp =
- dyn_cast<affine::AffineLoadOp>(loadOp.getOperation())) {
- AffineMap affineMap = affineLoadOp.getAffineMap();
- auto expandedIndices = calculateExpandedAccessIndices(
- affineMap, indices, loadOp.getLoc(), rewriter);
- indices.assign(expandedIndices.begin(), expandedIndices.end());
- }
SmallVector<Value> sourceIndices;
- // memref.load and affine.load guarantee that indexes start inbounds
- // while the vector operations don't. This impacts if our linearization
- // is `disjoint`
- if (failed(resolveSourceIndicesExpandShape(
- loadOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices,
- isa<affine::AffineLoadOp, memref::LoadOp>(loadOp.getOperation()))))
- return failure();
+ // memref.load guarantees that indexes start inbounds while the vector
+ // operations don't. This impacts if our linearization is `disjoint`
+ resolveSourceIndicesExpandShape(loadOp.getLoc(), rewriter, expandShapeOp,
+ loadOp.getIndices(), sourceIndices,
+ isa<memref::LoadOp>(loadOp.getOperation()));
return llvm::TypeSwitch<Operation *, LogicalResult>(loadOp)
- .Case([&](affine::AffineLoadOp op) {
- rewriter.replaceOpWithNewOp<affine::AffineLoadOp>(
- loadOp, expandShapeOp.getViewSource(), sourceIndices);
- return success();
- })
.Case([&](memref::LoadOp op) {
rewriter.replaceOpWithNewOp<memref::LoadOp>(
loadOp, expandShapeOp.getViewSource(), sourceIndices,
@@ -407,26 +357,10 @@ LogicalResult LoadOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
if (!collapseShapeOp)
return failure();
- SmallVector<Value> indices(loadOp.getIndices().begin(),
- loadOp.getIndices().end());
- // For affine ops, we need to apply the map to get the operands to get the
- // "actual" indices.
- if (auto affineLoadOp =
- dyn_cast<affine::AffineLoadOp>(loadOp.getOperation())) {
- AffineMap affineMap = affineLoadOp.getAffineMap();
- auto expandedIndices = calculateExpandedAccessIndices(
- affineMap, indices, loadOp.getLoc(), rewriter);
- indices.assign(expandedIndices.begin(), expandedIndices.end());
- }
SmallVector<Value> sourceIndices;
- if (failed(resolveSourceIndicesCollapseShape(
- loadOp.getLoc(), rewriter, collapseShapeOp, indices, sourceIndices)))
- return failure();
+ resolveSourceIndicesCollapseShape(loadOp.getLoc(), rewriter, collapseShapeOp,
+ loadOp.getIndices(), sourceIndices);
llvm::TypeSwitch<Operation *, void>(loadOp)
- .Case([&](affine::AffineLoadOp op) {
- rewriter.replaceOpWithNewOp<affine::AffineLoadOp>(
- loadOp, collapseShapeOp.getViewSource(), sourceIndices);
- })
.Case([&](memref::LoadOp op) {
rewriter.replaceOpWithNewOp<memref::LoadOp>(
loadOp, collapseShapeOp.getViewSource(), sourceIndices,
@@ -460,28 +394,13 @@ LogicalResult StoreOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
if (failed(preconditionResult))
return preconditionResult;
- SmallVector<Value> indices(storeOp.getIndices().begin(),
- storeOp.getIndices().end());
- // For affine ops, we need to apply the map to get the operands to get the
- // "actual" indices.
- if (auto affineStoreOp =
- dyn_cast<affine::AffineStoreOp>(storeOp.getOperation())) {
- AffineMap affineMap = affineStoreOp.getAffineMap();
- auto expandedIndices = calculateExpandedAccessIndices(
- affineMap, indices, storeOp.getLoc(), rewriter);
- indices.assign(expandedIndices.begin(), expandedIndices.end());
- }
SmallVector<Value> sourceIndices;
affine::resolveIndicesIntoOpWithOffsetsAndStrides(
rewriter, storeOp.getLoc(), subViewOp.getMixedOffsets(),
- subViewOp.getMixedStrides(), subViewOp.getDroppedDims(), indices,
- sourceIndices);
+ subViewOp.getMixedStrides(), subViewOp.getDroppedDims(),
+ storeOp.getIndices(), sourceIndices);
llvm::TypeSwitch<Operation *, void>(storeOp)
- .Case([&](affine::AffineStoreOp op) {
- rewriter.replaceOpWithNewOp<affine::AffineStoreOp>(
- op, op.getValue(), subViewOp.getSource(), sourceIndices);
- })
.Case([&](memref::StoreOp op) {
rewriter.replaceOpWithNewOp<memref::StoreOp>(
op, op.getValue(), subViewOp.getSource(), sourceIndices,
@@ -522,31 +441,13 @@ LogicalResult StoreOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
if (!expandShapeOp)
return failure();
- SmallVector<Value> indices(storeOp.getIndices().begin(),
- storeOp.getIndices().end());
- // For affine ops, we need to apply the map to get the operands to get the
- // "actual" indices.
- if (auto affineStoreOp =
- dyn_cast<affine::AffineStoreOp>(storeOp.getOperation())) {
- AffineMap affineMap = affineStoreOp.getAffineMap();
- auto expandedIndices = calculateExpandedAccessIndices(
- affineMap, indices, storeOp.getLoc(), rewriter);
- indices.assign(expandedIndices.begin(), expandedIndices.end());
- }
SmallVector<Value> sourceIndices;
- // memref.store and affine.store guarantee that indexes start inbounds
- // while the vector operations don't. This impacts if our linearization
- // is `disjoint`
- if (failed(resolveSourceIndicesExpandShape(
- storeOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices,
- isa<affine::AffineStoreOp, memref::StoreOp>(storeOp.getOperation()))))
- return failure();
+ // memref.store guarantees that indexes start inbounds while the vector
+ // operations don't. This impacts if our linearization is `disjoint`
+ resolveSourceIndicesExpandShape(storeOp.getLoc(), rewriter, expandShapeOp,
+ storeOp.getIndices(), sourceIndices,
+ isa<memref::StoreOp>(storeOp.getOperation()));
llvm::TypeSwitch<Operation *, void>(storeOp)
- .Case([&](affine::AffineStoreOp op) {
- rewriter.replaceOpWithNewOp<affine::AffineStoreOp>(
- storeOp, op.getValueToStore(), expandShapeOp.getViewSource(),
- sourceIndices);
- })
.Case([&](memref::StoreOp op) {
rewriter.replaceOpWithNewOp<memref::StoreOp>(
storeOp, op.getValueToStore(), expandShapeOp.getViewSource(),
@@ -575,27 +476,10 @@ LogicalResult StoreOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
if (!collapseShapeOp)
return failure();
- SmallVector<Value> indices(storeOp.getIndices().begin(),
- storeOp.getIndices().end());
- // For affine ops, we need to apply the map to get the operands to get the
- // "actual" indices.
- if (auto affineStoreOp =
- dyn_cast<affine::AffineStoreOp>(storeOp.getOperation())) {
- AffineMap affineMap = affineStoreOp.getAffineMap();
- auto expandedIndices = calculateExpandedAccessIndices(
- affineMap, indices, storeOp.getLoc(), rewriter);
- indices.assign(expandedIndices.begin(), expandedIndices.end());
- }
SmallVector<Value> sourceIndices;
- if (failed(resolveSourceIndicesCollapseShape(
- storeOp.getLoc(), rewriter, collapseShapeOp, indices, sourceIndices)))
- return failure();
+ resolveSourceIndicesCollapseShape(storeOp.getLoc(), rewriter, collapseShapeOp,
+ storeOp.getIndices(), sourceIndices);
llvm::TypeSwitch<Operation *, void>(storeOp)
- .Case([&](affine::AffineStoreOp op) {
- rewriter.replaceOpWithNewOp<affine::AffineStoreOp>(
- storeOp, op.getValueToStore(), collapseShapeOp.getViewSource(),
- sourceIndices);
- })
.Case([&](memref::StoreOp op) {
rewriter.replaceOpWithNewOp<memref::StoreOp>(
storeOp, op.getValueToStore(), collapseShapeOp.getViewSource(),
@@ -630,29 +514,27 @@ LogicalResult NVGPUAsyncCopyOpSubViewOpFolder::matchAndRewrite(
"source or destination");
// If the source is a subview, we need to resolve the indices.
- SmallVector<Value> srcindices(copyOp.getSrcIndices().begin(),
- copyOp.getSrcIndices().end());
- SmallVector<Value> foldedSrcIndices(srcindices);
+ SmallVector<Value> foldedSrcIndices(copyOp.getSrcIndices().begin(),
+ copyOp.getSrcIndices().end());
if (srcSubViewOp) {
LLVM_DEBUG(DBGS() << "srcSubViewOp : " << srcSubViewOp << "\n");
affine::resolveIndicesIntoOpWithOffsetsAndStrides(
rewriter, copyOp.getLoc(), srcSubViewOp.getMixedOffsets(),
srcSubViewOp.getMixedStrides(), srcSubViewOp.getDroppedDims(),
- srcindices, foldedSrcIndices);
+ copyOp.getSrcIndices(), foldedSrcIndices);
}
// If the destination is a subview, we need to resolve the indices.
- SmallVector<Value> dstindices(copyOp.getDstIndices().begin(),
- copyOp.getDstIndices().end());
- SmallVector<Value> foldedDstIndices(dstindices);
+ SmallVector<Value> foldedDstIndices(copyOp.getDstIndices().begin(),
+ copyOp.getDstIndices().end());
if (dstSubViewOp) {
LLVM_DEBUG(DBGS() << "dstSubViewOp : " << dstSubViewOp << "\n");
affine::resolveIndicesIntoOpWithOffsetsAndStrides(
rewriter, copyOp.getLoc(), dstSubViewOp.getMixedOffsets(),
dstSubViewOp.getMixedStrides(), dstSubViewOp.getDroppedDims(),
- dstindices, foldedDstIndices);
+ copyOp.getDstIndices(), foldedDstIndices);
}
// Replace the copy op with a new copy op that uses the source and destination
@@ -669,33 +551,27 @@ LogicalResult NVGPUAsyncCopyOpSubViewOpFolder::matchAndRewrite(
}
void memref::populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns) {
- patterns.add<LoadOpOfSubViewOpFolder<affine::AffineLoadOp>,
- LoadOpOfSubViewOpFolder<memref::LoadOp>,
+ patterns.add<LoadOpOfSubViewOpFolder<memref::LoadOp>,
LoadOpOfSubViewOpFolder<nvgpu::LdMatrixOp>,
LoadOpOfSubViewOpFolder<vector::LoadOp>,
LoadOpOfSubViewOpFolder<vector::MaskedLoadOp>,
LoadOpOfSubViewOpFolder<vector::TransferReadOp>,
LoadOpOfSubViewOpFolder<gpu::SubgroupMmaLoadMatrixOp>,
- StoreOpOfSubViewOpFolder<affine::AffineStoreOp>,
StoreOpOfSubViewOpFolder<memref::StoreOp>,
StoreOpOfSubViewOpFolder<vector::TransferWriteOp>,
StoreOpOfSubViewOpFolder<vector::StoreOp>,
StoreOpOfSubViewOpFolder<vector::MaskedStoreOp>,
StoreOpOfSubViewOpFolder<gpu::SubgroupMmaStoreMatrixOp>,
- LoadOpOfExpandShapeOpFolder<affine::AffineLoadOp>,
LoadOpOfExpandShapeOpFolder<memref::LoadOp>,
LoadOpOfExpandShapeOpFolder<vector::LoadOp>,
LoadOpOfExpandShapeOpFolder<vector::MaskedLoadOp>,
LoadOpOfExpandShapeOpFolder<vector::TransferReadOp>,
- StoreOpOfExpandShapeOpFolder<affine::AffineStoreOp>,
StoreOpOfExpandShapeOpFolder<memref::StoreOp>,
StoreOpOfExpandShapeOpFolder<vector::StoreOp>,
StoreOpOfExpandShapeOpFolder<vector::MaskedStoreOp>,
- LoadOpOfCollapseShapeOpFolder<affine::AffineLoadOp>,
LoadOpOfCollapseShapeOpFolder<memref::LoadOp>,
LoadOpOfCollapseShapeOpFolder<vector::LoadOp>,
LoadOpOfCollapseShapeOpFolder<vector::MaskedLoadOp>,
- StoreOpOfCollapseShapeOpFolder<affine::AffineStoreOp>,
StoreOpOfCollapseShapeOpFolder<memref::StoreOp>,
StoreOpOfCollapseShapeOpFolder<vector::StoreOp>,
StoreOpOfCollapseShapeOpFolder<vector::MaskedStoreOp>,
diff --git a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
index e5486988947c6..2d341dce665e5 100644
--- a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
+++ b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
@@ -223,10 +223,11 @@ MemrefValue skipViewLikeOps(MemrefValue source) {
return source;
}
-LogicalResult resolveSourceIndicesExpandShape(
- Location loc, PatternRewriter &rewriter,
- memref::ExpandShapeOp expandShapeOp, ValueRange indices,
- SmallVectorImpl<Value> &sourceIndices, bool startsInbounds) {
+void resolveSourceIndicesExpandShape(Location loc, PatternRewriter &rewriter,
+ memref::ExpandShapeOp expandShapeOp,
+ ValueRange indices,
+ SmallVectorImpl<Value> &sourceIndices,
+ bool startsInbounds) {
SmallVector<OpFoldResult> destShape = expandShapeOp.getMixedOutputShape();
// Traverse all reassociation groups to determine the appropriate indices
@@ -246,14 +247,12 @@ LogicalResult resolveSourceIndicesExpandShape(
rewriter, loc, groupIndices, groupBasis, /*disjoint=*/startsInbounds);
sourceIndices.push_back(collapsedIndex);
}
- return success();
}
-LogicalResult
-resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter,
- memref::CollapseShapeOp collapseShapeOp,
- ValueRange indices,
- SmallVectorImpl<Value> &sourceIndices) {
+void resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter,
+ memref::CollapseShapeOp collapseShapeOp,
+ ValueRange indices,
+ SmallVectorImpl<Value> &sourceIndices) {
// Note: collapse_shape requires a strided memref, we can do this.
auto metadata = memref::ExtractStridedMetadataOp::create(
rewriter, loc, collapseShapeOp.getSrc());
@@ -285,7 +284,6 @@ resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter,
getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
}
}
- return success();
}
} // namespace memref
diff --git a/mlir/test/Dialect/Affine/fold-memref-alias-ops.mlir b/mlir/test/Dialect/Affine/fold-memref-alias-ops.mlir
new file mode 100644
index 0000000000000..5e3e107531802
--- /dev/null
+++ b/mlir/test/Dialect/Affine/fold-memref-alias-ops.mlir
@@ -0,0 +1,191 @@
+// RUN: mlir-opt -affine-fold-memref-alias-ops -split-input-file %s | FileCheck %s
+
+// Tests for folding memref aliasing ops (expand/collapse_shape and subview) into
+// affine memory access operations, which require specific handling (and thus
+// their own pass) because of the embedded affine maps.
+
+// CHECK-LABEL: func @fold_static_stride_subview_with_affine_load_store
+func.func @fold_static_stride_subview_with_affine_load_store(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index) -> f32 {
+ %0 = memref.subview %arg0[%arg1, %arg2][4, 4][2, 3] : memref<12x32xf32> to memref<4x4xf32, strided<[64, 3], offset: ?>>
+ %1 = affine.load %0[%arg3, %arg4] : memref<4x4xf32, strided<[64, 3], offset: ?>>
+ // CHECK-NEXT: affine.apply
+ // CHECK-NEXT: affine.apply
+ // CHECK-NEXT: affine.load
+ affine.store %1, %0[%arg3, %arg4] : memref<4x4xf32, strided<[64, 3], offset: ?>>
+ // CHECK-NEXT: affine.apply
+ // CHECK-NEXT: affine.apply
+ // CHECK-NEXT: affine.store
+ // CHECK-NEXT: return
+ return %1 : f32
+}
+
+// -----
+
+// CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape
+// CHECK-SAME: (%[[ARG0:.*]]: memref<12x32xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index) -> f32 {
+func.func @fold_static_stride_subview_with_affine_load_store_expand_shape(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index) -> f32 {
+ %0 = memref.expand_shape %arg0 [[0, 1], [2]] output_shape [2, 6, 32] : memref<12x32xf32> into memref<2x6x32xf32>
+ %1 = affine.load %0[%arg1, %arg2, %arg3] : memref<2x6x32xf32>
+ return %1 : f32
+}
+// CHECK: %[[INDEX:.*]] = affine.linearize_index disjoint [%[[ARG1]], %[[ARG2]]] by (2, 6)
+// CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[INDEX]], %[[ARG3]]] : memref<12x32xf32>
+// CHECK-NEXT: return %[[RESULT]] : f32
+
+// -----
+
+// CHECK-LABEL: @fold_static_stride_subview_with_affine_load_store_collapse_shape
+// CHECK-SAME: (%[[ARG0:.*]]: memref<2x6x32xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
+func.func @fold_static_stride_subview_with_affine_load_store_collapse_shape(%arg0 : memref<2x6x32xf32>, %arg1 : index, %arg2 : index) -> f32 {
+ %0 = memref.collapse_shape %arg0 [[0, 1], [2]] : memref<2x6x32xf32> into memref<12x32xf32>
+ %1 = affine.load %0[%arg1, %arg2] : memref<12x32xf32>
+ return %1 : f32
+}
+// CHECK-NEXT: %[[MODIFIED_INDEXES:.*]]:2 = affine.delinearize_index %[[ARG1]] into (2, 6)
+// CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[MODIFIED_INDEXES]]#0, %[[MODIFIED_INDEXES]]#1, %[[ARG2]]] : memref<2x6x32xf32>
+// CHECK-NEXT: return %[[RESULT]] : f32
+
+// -----
+
+// CHECK-LABEL: @fold_dynamic_size_collapse_shape_with_affine_load
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?x6x32xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
+func.func @fold_dynamic_size_collapse_shape_with_affine_load(%arg0 : memref<?x6x32xf32>, %arg1 : index, %arg2 : index) -> f32 {
+ %0 = memref.collapse_shape %arg0 [[0, 1], [2]] : memref<?x6x32xf32> into memref<?x32xf32>
+ %1 = affine.load %0[%arg1, %arg2] : memref<?x32xf32>
+ return %1 : f32
+}
+// CHECK-NEXT: %{{.*}}, %{{.*}}, %[[SIZES:.*]]:3, %{{.*}}:3 = memref.extract_strided_metadata %[[ARG0]]
+// CHECK-NEXT: %[[MODIFIED_INDEXES:.*]]:2 = affine.delinearize_index %[[ARG1]] into (%[[SIZES]]#0, 6)
+// CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[MODIFIED_INDEXES]]#0, %[[MODIFIED_INDEXES]]#1, %[[ARG2]]] : memref<?x6x32xf32>
+// CHECK-NEXT: return %[[RESULT]] : f32
+
+// -----
+
+// CHECK-LABEL: @fold_fully_dynamic_size_collapse_shape_with_affine_load
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
+func.func @fold_fully_dynamic_size_collapse_shape_with_affine_load(%arg0 : memref<?x?x?xf32>, %arg1 : index, %arg2 : index) -> f32 {
+ %0 = memref.collapse_shape %arg0 [[0, 1], [2]] : memref<?x?x?xf32> into memref<?x?xf32>
+ %1 = affine.load %0[%arg1, %arg2] : memref<?x?xf32>
+ return %1 : f32
+}
+// CHECK-NEXT: %{{.*}}, %{{.*}}, %[[SIZES:.*]]:3, %{{.*}}:3 = memref.extract_strided_metadata %[[ARG0]]
+// CHECK-NEXT: %[[MODIFIED_INDEXES:.*]]:2 = affine.delinearize_index %[[ARG1]] into (%[[SIZES]]#0, %[[SIZES]]#1)
+// CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[MODIFIED_INDEXES]]#0, %[[MODIFIED_INDEXES]]#1, %[[ARG2]]] : memref<?x?x?xf32>
+// CHECK-NEXT: return %[[RESULT]] : f32
+
+
+// -----
+
+// CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape_3d
+// CHECK-SAME: (%[[ARG0:.*]]: memref<12x32xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index) -> f32 {
+func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_3d(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4: index) -> f32 {
+ %0 = memref.expand_shape %arg0 [[0, 1, 2], [3]] output_shape [2, 2, 3, 32] : memref<12x32xf32> into memref<2x2x3x32xf32>
+ %1 = affine.load %0[%arg1, %arg2, %arg3, %arg4] : memref<2x2x3x32xf32>
+ return %1 : f32
+}
+// CHECK: %[[INDEX:.*]] = affine.linearize_index disjoint [%[[ARG1]], %[[ARG2]], %[[ARG3]]] by (2, 2, 3)
+// CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[INDEX]], %[[ARG4]]] : memref<12x32xf32>
+// CHECK-NEXT: return %[[RESULT]] : f32
+
+// -----
+
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0)[s0] -> (d0 + s0)>
+// CHECK-LABEL: fold_memref_alias_expand_shape_subview_load_store_dynamic_dim
+// CHECK-SAME: (%[[ARG0:.*]]: memref<2048x16xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index)
+func.func @fold_memref_alias_expand_shape_subview_load_store_dynamic_dim(%alloc: memref<2048x16xf32>, %c10: index, %c5: index, %c0: index, %sz0: index) {
+ %subview = memref.subview %alloc[%c5, 0] [%c10, 16] [1, 1] : memref<2048x16xf32> to memref<?x16xf32, strided<[16, 1], offset: ?>>
+ %expand_shape = memref.expand_shape %subview [[0], [1, 2, 3]] output_shape [%sz0, 1, 8, 2] : memref<?x16xf32, strided<[16, 1], offset: ?>> into memref<?x1x8x2xf32, strided<[16, 16, 2, 1], offset: ?>>
+ %dim = memref.dim %expand_shape, %c0 : memref<?x1x8x2xf32, strided<[16, 16, 2, 1], offset: ?>>
+
+ affine.for %arg6 = 0 to %dim step 64 {
+ affine.for %arg7 = 0 to 16 step 16 {
+ %dummy_load = affine.load %expand_shape[%arg6, 0, %arg7, %arg7] : memref<?x1x8x2xf32, strided<[16, 16, 2, 1], offset: ?>>
+ affine.store %dummy_load, %subview[%arg6, %arg7] : memref<?x16xf32, strided<[16, 1], offset: ?>>
+ }
+ }
+ return
+}
+// CHECK-NEXT: %[[C0:.*]] = arith.constant 0
+// CHECK-NEXT: memref.subview
+// CHECK-NEXT: %[[EXPAND_SHAPE:.*]] = memref.expand_shape
+// CHECK-NEXT: %[[DIM:.*]] = memref.dim %[[EXPAND_SHAPE]], %[[ARG3]] : memref<?x1x8x2xf32, strided<[16, 16, 2, 1], offset: ?>>
+// CHECK-NEXT: affine.for %[[ARG5:.*]] = 0 to %[[DIM]] step 64 {
+// CHECK-NEXT: affine.for %[[ARG6:.*]] = 0 to 16 step 16 {
+// CHECK-NEXT: %[[VAL0:.*]] = affine.linearize_index disjoint [%[[C0]], %[[ARG6]], %[[ARG6]]] by (1, 8, 2)
+// CHECK-NEXT: %[[VAL1:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])[%[[ARG2]]]
+// CHECK-NEXT: %[[VAL2:.*]] = affine.load %[[ARG0]][%[[VAL1]], %[[VAL0]]] : memref<2048x16xf32>
+// CHECK-NEXT: %[[VAL3:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])[%[[ARG2]]]
+// CHECK-NEXT: affine.store %[[VAL2]], %[[ARG0]][%[[VAL3]], %[[ARG6]]] : memref<2048x16xf32>
+
+// -----
+
+// CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape_loops
+// CHECK-SAME: (%[[ARG0:.*]]: memref<1024x1024xf32>, %[[ARG1:.*]]: memref<1xf32>, %[[ARG2:.*]]: index)
+func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_loops(%arg0: memref<1024x1024xf32>, %arg1: memref<1xf32>, %arg2: index) -> f32 {
+ %0 = memref.expand_shape %arg0 [[0, 1], [2, 3]] output_shape [1, 1024, 1024, 1] : memref<1024x1024xf32> into memref<1x1024x1024x1xf32>
+ affine.for %arg3 = 0 to 1 {
+ affine.for %arg4 = 0 to 1024 {
+ affine.for %arg5 = 0 to 1020 {
+ affine.for %arg6 = 0 to 1 {
+ %1 = affine.load %0[%arg3, %arg4, %arg5, %arg6] : memref<1x1024x1024x1xf32>
+ affine.store %1, %arg1[%arg2] : memref<1xf32>
+ }
+ }
+ }
+ }
+ %2 = affine.load %arg1[%arg2] : memref<1xf32>
+ return %2 : f32
+}
+// CHECK-NEXT: affine.for %[[ARG3:.*]] = 0 to 1 {
+// CHECK-NEXT: affine.for %[[ARG4:.*]] = 0 to 1024 {
+// CHECK-NEXT: affine.for %[[ARG5:.*]] = 0 to 1020 {
+// CHECK-NEXT: affine.for %[[ARG6:.*]] = 0 to 1 {
+// CHECK-NEXT: %[[IDX1:.*]] = affine.linearize_index disjoint [%[[ARG3]], %[[ARG4]]] by (1, 1024)
+// CHECK-NEXT: %[[IDX2:.*]] = affine.linearize_index disjoint [%[[ARG5]], %[[ARG6]]] by (1024, 1)
+// CHECK-NEXT: affine.load %[[ARG0]][%[[IDX1]], %[[IDX2]]] : memref<1024x1024xf32>
+
+// -----
+
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
+// CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape_when_access_index_is_an_expression
+// CHECK-SAME: (%[[ARG0:.*]]: memref<1024x1024xf32>, %[[ARG1:.*]]: memref<1xf32>, %[[ARG2:.*]]: index)
+func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_when_access_index_is_an_expression(%arg0: memref<1024x1024xf32>, %arg1: memref<1xf32>, %arg2: index) -> f32 {
+ %0 = memref.expand_shape %arg0 [[0, 1], [2, 3]] output_shape [1, 1024, 1024, 1] : memref<1024x1024xf32> into memref<1x1024x1024x1xf32>
+ affine.for %arg3 = 0 to 1 {
+ affine.for %arg4 = 0 to 1024 {
+ affine.for %arg5 = 0 to 1020 {
+ affine.for %arg6 = 0 to 1 {
+ %1 = affine.load %0[%arg3, %arg4 + %arg3, %arg5, %arg6] : memref<1x1024x1024x1xf32>
+ affine.store %1, %arg1[%arg2] : memref<1xf32>
+ }
+ }
+ }
+ }
+ %2 = affine.load %arg1[%arg2] : memref<1xf32>
+ return %2 : f32
+}
+// CHECK-NEXT: affine.for %[[ARG3:.*]] = 0 to 1 {
+// CHECK-NEXT: affine.for %[[ARG4:.*]] = 0 to 1024 {
+// CHECK-NEXT: affine.for %[[ARG5:.*]] = 0 to 1020 {
+// CHECK-NEXT: affine.for %[[ARG6:.*]] = 0 to 1 {
+// CHECK-NEXT: %[[TMP1:.*]] = affine.apply #[[$MAP0]](%[[ARG3]], %[[ARG4]])
+// CHECK-NEXT: %[[TMP2:.*]] = affine.linearize_index disjoint [%[[ARG3]], %[[TMP1]]] by (1, 1024)
+// CHECK-NEXT: %[[TMP3:.*]] = affine.linearize_index disjoint [%[[ARG5]], %[[ARG6]]] by (1024, 1)
+// CHECK-NEXT: affine.load %[[ARG0]][%[[TMP2]], %[[TMP3]]] : memref<1024x1024xf32>
+
+// -----
+
+// CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_collapse_shape_with_0d_result
+// CHECK-SAME: (%[[ARG0:.*]]: memref<1xf32>, %[[ARG1:.*]]: memref<1xf32>)
+func.func @fold_static_stride_subview_with_affine_load_store_collapse_shape_with_0d_result(%arg0: memref<1xf32>, %arg1: memref<1xf32>) -> memref<1xf32> {
+ %0 = memref.collapse_shape %arg0 [] : memref<1xf32> into memref<f32>
+ affine.for %arg2 = 0 to 3 {
+ %1 = affine.load %0[] : memref<f32>
+ affine.store %1, %arg1[0] : memref<1xf32>
+ }
+ return %arg1 : memref<1xf32>
+}
+// CHECK-NEXT: %[[ZERO:.*]] = arith.constant 0 : index
+// CHECK-NEXT: affine.for %{{.*}} = 0 to 3 {
+// CHECK-NEXT: affine.load %[[ARG0]][%[[ZERO]]] : memref<1xf32>
+
diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
index ca91b0141f593..93e5ba462584a 100644
--- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
+++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
@@ -387,95 +387,6 @@ func.func @fold_masked_vector_transfer_write_with_rank_reducing_subview(
// -----
-// Test with affine.load/store ops. We only do a basic test here since the
-// logic is identical to that with memref.load/store ops. The same affine.apply
-// ops would be generated.
-
-// CHECK-LABEL: func @fold_static_stride_subview_with_affine_load_store
-func.func @fold_static_stride_subview_with_affine_load_store(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index) -> f32 {
- %0 = memref.subview %arg0[%arg1, %arg2][4, 4][2, 3] : memref<12x32xf32> to memref<4x4xf32, strided<[64, 3], offset: ?>>
- %1 = affine.load %0[%arg3, %arg4] : memref<4x4xf32, strided<[64, 3], offset: ?>>
- // CHECK-NEXT: affine.apply
- // CHECK-NEXT: affine.apply
- // CHECK-NEXT: affine.load
- affine.store %1, %0[%arg3, %arg4] : memref<4x4xf32, strided<[64, 3], offset: ?>>
- // CHECK-NEXT: affine.apply
- // CHECK-NEXT: affine.apply
- // CHECK-NEXT: affine.store
- // CHECK-NEXT: return
- return %1 : f32
-}
-
-// -----
-
-// CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape
-// CHECK-SAME: (%[[ARG0:.*]]: memref<12x32xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index) -> f32 {
-func.func @fold_static_stride_subview_with_affine_load_store_expand_shape(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index) -> f32 {
- %0 = memref.expand_shape %arg0 [[0, 1], [2]] output_shape [2, 6, 32] : memref<12x32xf32> into memref<2x6x32xf32>
- %1 = affine.load %0[%arg1, %arg2, %arg3] : memref<2x6x32xf32>
- return %1 : f32
-}
-// CHECK: %[[INDEX:.*]] = affine.linearize_index disjoint [%[[ARG1]], %[[ARG2]]] by (2, 6)
-// CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[INDEX]], %[[ARG3]]] : memref<12x32xf32>
-// CHECK-NEXT: return %[[RESULT]] : f32
-
-// -----
-
-// CHECK-LABEL: @fold_static_stride_subview_with_affine_load_store_collapse_shape
-// CHECK-SAME: (%[[ARG0:.*]]: memref<2x6x32xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
-func.func @fold_static_stride_subview_with_affine_load_store_collapse_shape(%arg0 : memref<2x6x32xf32>, %arg1 : index, %arg2 : index) -> f32 {
- %0 = memref.collapse_shape %arg0 [[0, 1], [2]] : memref<2x6x32xf32> into memref<12x32xf32>
- %1 = affine.load %0[%arg1, %arg2] : memref<12x32xf32>
- return %1 : f32
-}
-// CHECK-NEXT: %[[MODIFIED_INDEXES:.*]]:2 = affine.delinearize_index %[[ARG1]] into (2, 6)
-// CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[MODIFIED_INDEXES]]#0, %[[MODIFIED_INDEXES]]#1, %[[ARG2]]] : memref<2x6x32xf32>
-// CHECK-NEXT: return %[[RESULT]] : f32
-
-// -----
-
-// CHECK-LABEL: @fold_dynamic_size_collapse_shape_with_affine_load
-// CHECK-SAME: (%[[ARG0:.*]]: memref<?x6x32xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
-func.func @fold_dynamic_size_collapse_shape_with_affine_load(%arg0 : memref<?x6x32xf32>, %arg1 : index, %arg2 : index) -> f32 {
- %0 = memref.collapse_shape %arg0 [[0, 1], [2]] : memref<?x6x32xf32> into memref<?x32xf32>
- %1 = affine.load %0[%arg1, %arg2] : memref<?x32xf32>
- return %1 : f32
-}
-// CHECK-NEXT: %{{.*}}, %{{.*}}, %[[SIZES:.*]]:3, %{{.*}}:3 = memref.extract_strided_metadata %[[ARG0]]
-// CHECK-NEXT: %[[MODIFIED_INDEXES:.*]]:2 = affine.delinearize_index %[[ARG1]] into (%[[SIZES]]#0, 6)
-// CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[MODIFIED_INDEXES]]#0, %[[MODIFIED_INDEXES]]#1, %[[ARG2]]] : memref<?x6x32xf32>
-// CHECK-NEXT: return %[[RESULT]] : f32
-
-// -----
-
-// CHECK-LABEL: @fold_fully_dynamic_size_collapse_shape_with_affine_load
-// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
-func.func @fold_fully_dynamic_size_collapse_shape_with_affine_load(%arg0 : memref<?x?x?xf32>, %arg1 : index, %arg2 : index) -> f32 {
- %0 = memref.collapse_shape %arg0 [[0, 1], [2]] : memref<?x?x?xf32> into memref<?x?xf32>
- %1 = affine.load %0[%arg1, %arg2] : memref<?x?xf32>
- return %1 : f32
-}
-// CHECK-NEXT: %{{.*}}, %{{.*}}, %[[SIZES:.*]]:3, %{{.*}}:3 = memref.extract_strided_metadata %[[ARG0]]
-// CHECK-NEXT: %[[MODIFIED_INDEXES:.*]]:2 = affine.delinearize_index %[[ARG1]] into (%[[SIZES]]#0, %[[SIZES]]#1)
-// CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[MODIFIED_INDEXES]]#0, %[[MODIFIED_INDEXES]]#1, %[[ARG2]]] : memref<?x?x?xf32>
-// CHECK-NEXT: return %[[RESULT]] : f32
-
-
-// -----
-
-// CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape_3d
-// CHECK-SAME: (%[[ARG0:.*]]: memref<12x32xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index) -> f32 {
-func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_3d(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4: index) -> f32 {
- %0 = memref.expand_shape %arg0 [[0, 1, 2], [3]] output_shape [2, 2, 3, 32] : memref<12x32xf32> into memref<2x2x3x32xf32>
- %1 = affine.load %0[%arg1, %arg2, %arg3, %arg4] : memref<2x2x3x32xf32>
- return %1 : f32
-}
-// CHECK: %[[INDEX:.*]] = affine.linearize_index disjoint [%[[ARG1]], %[[ARG2]], %[[ARG3]]] by (2, 2, 3)
-// CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[INDEX]], %[[ARG4]]] : memref<12x32xf32>
-// CHECK-NEXT: return %[[RESULT]] : f32
-
-// -----
-
// CHECK-LABEL: fold_dynamic_subview_with_memref_load_expand_shape
// CHECK-SAME: (%[[ARG0:.*]]: memref<16x?xf32, strided<[16, 1]>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index) -> f32
func.func @fold_dynamic_subview_with_memref_load_expand_shape(%arg0 : memref<16x?xf32, strided<[16, 1]>>, %arg1 : index, %arg2 : index, %sz0: index) -> f32 {
@@ -510,95 +421,9 @@ func.func @fold_dynamic_subview_with_memref_store_expand_shape(%arg0 : memref<16
// -----
-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0)[s0] -> (d0 + s0)>
-// CHECK-LABEL: fold_memref_alias_expand_shape_subview_load_store_dynamic_dim
-// CHECK-SAME: (%[[ARG0:.*]]: memref<2048x16xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index)
-func.func @fold_memref_alias_expand_shape_subview_load_store_dynamic_dim(%alloc: memref<2048x16xf32>, %c10: index, %c5: index, %c0: index, %sz0: index) {
- %subview = memref.subview %alloc[%c5, 0] [%c10, 16] [1, 1] : memref<2048x16xf32> to memref<?x16xf32, strided<[16, 1], offset: ?>>
- %expand_shape = memref.expand_shape %subview [[0], [1, 2, 3]] output_shape [%sz0, 1, 8, 2] : memref<?x16xf32, strided<[16, 1], offset: ?>> into memref<?x1x8x2xf32, strided<[16, 16, 2, 1], offset: ?>>
- %dim = memref.dim %expand_shape, %c0 : memref<?x1x8x2xf32, strided<[16, 16, 2, 1], offset: ?>>
-
- affine.for %arg6 = 0 to %dim step 64 {
- affine.for %arg7 = 0 to 16 step 16 {
- %dummy_load = affine.load %expand_shape[%arg6, 0, %arg7, %arg7] : memref<?x1x8x2xf32, strided<[16, 16, 2, 1], offset: ?>>
- affine.store %dummy_load, %subview[%arg6, %arg7] : memref<?x16xf32, strided<[16, 1], offset: ?>>
- }
- }
- return
-}
-// CHECK-NEXT: %[[C0:.*]] = arith.constant 0
-// CHECK-NEXT: memref.subview
-// CHECK-NEXT: %[[EXPAND_SHAPE:.*]] = memref.expand_shape
-// CHECK-NEXT: %[[DIM:.*]] = memref.dim %[[EXPAND_SHAPE]], %[[ARG3]] : memref<?x1x8x2xf32, strided<[16, 16, 2, 1], offset: ?>>
-// CHECK-NEXT: affine.for %[[ARG5:.*]] = 0 to %[[DIM]] step 64 {
-// CHECK-NEXT: affine.for %[[ARG6:.*]] = 0 to 16 step 16 {
-// CHECK-NEXT: %[[VAL0:.*]] = affine.linearize_index disjoint [%[[C0]], %[[ARG6]], %[[ARG6]]] by (1, 8, 2)
-// CHECK-NEXT: %[[VAL1:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])[%[[ARG2]]]
-// CHECK-NEXT: %[[VAL2:.*]] = affine.load %[[ARG0]][%[[VAL1]], %[[VAL0]]] : memref<2048x16xf32>
-// CHECK-NEXT: %[[VAL3:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])[%[[ARG2]]]
-// CHECK-NEXT: affine.store %[[VAL2]], %[[ARG0]][%[[VAL3]], %[[ARG6]]] : memref<2048x16xf32>
-
-// -----
-
-// CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape
-// CHECK-SAME: (%[[ARG0:.*]]: memref<1024x1024xf32>, %[[ARG1:.*]]: memref<1xf32>, %[[ARG2:.*]]: index)
-func.func @fold_static_stride_subview_with_affine_load_store_expand_shape(%arg0: memref<1024x1024xf32>, %arg1: memref<1xf32>, %arg2: index) -> f32 {
- %0 = memref.expand_shape %arg0 [[0, 1], [2, 3]] output_shape [1, 1024, 1024, 1] : memref<1024x1024xf32> into memref<1x1024x1024x1xf32>
- affine.for %arg3 = 0 to 1 {
- affine.for %arg4 = 0 to 1024 {
- affine.for %arg5 = 0 to 1020 {
- affine.for %arg6 = 0 to 1 {
- %1 = affine.load %0[%arg3, %arg4, %arg5, %arg6] : memref<1x1024x1024x1xf32>
- affine.store %1, %arg1[%arg2] : memref<1xf32>
- }
- }
- }
- }
- %2 = affine.load %arg1[%arg2] : memref<1xf32>
- return %2 : f32
-}
-// CHECK-NEXT: affine.for %[[ARG3:.*]] = 0 to 1 {
-// CHECK-NEXT: affine.for %[[ARG4:.*]] = 0 to 1024 {
-// CHECK-NEXT: affine.for %[[ARG5:.*]] = 0 to 1020 {
-// CHECK-NEXT: affine.for %[[ARG6:.*]] = 0 to 1 {
-// CHECK-NEXT: %[[IDX1:.*]] = affine.linearize_index disjoint [%[[ARG3]], %[[ARG4]]] by (1, 1024)
-// CHECK-NEXT: %[[IDX2:.*]] = affine.linearize_index disjoint [%[[ARG5]], %[[ARG6]]] by (1024, 1)
-// CHECK-NEXT: affine.load %[[ARG0]][%[[IDX1]], %[[IDX2]]] : memref<1024x1024xf32>
-
-// -----
-
-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
-// CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape_when_access_index_is_an_expression
-// CHECK-SAME: (%[[ARG0:.*]]: memref<1024x1024xf32>, %[[ARG1:.*]]: memref<1xf32>, %[[ARG2:.*]]: index)
-func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_when_access_index_is_an_expression(%arg0: memref<1024x1024xf32>, %arg1: memref<1xf32>, %arg2: index) -> f32 {
- %0 = memref.expand_shape %arg0 [[0, 1], [2, 3]] output_shape [1, 1024, 1024, 1] : memref<1024x1024xf32> into memref<1x1024x1024x1xf32>
- affine.for %arg3 = 0 to 1 {
- affine.for %arg4 = 0 to 1024 {
- affine.for %arg5 = 0 to 1020 {
- affine.for %arg6 = 0 to 1 {
- %1 = affine.load %0[%arg3, %arg4 + %arg3, %arg5, %arg6] : memref<1x1024x1024x1xf32>
- affine.store %1, %arg1[%arg2] : memref<1xf32>
- }
- }
- }
- }
- %2 = affine.load %arg1[%arg2] : memref<1xf32>
- return %2 : f32
-}
-// CHECK-NEXT: affine.for %[[ARG3:.*]] = 0 to 1 {
-// CHECK-NEXT: affine.for %[[ARG4:.*]] = 0 to 1024 {
-// CHECK-NEXT: affine.for %[[ARG5:.*]] = 0 to 1020 {
-// CHECK-NEXT: affine.for %[[ARG6:.*]] = 0 to 1 {
-// CHECK-NEXT: %[[TMP1:.*]] = affine.apply #[[$MAP0]](%[[ARG3]], %[[ARG4]])
-// CHECK-NEXT: %[[TMP2:.*]] = affine.linearize_index disjoint [%[[ARG3]], %[[TMP1]]] by (1, 1024)
-// CHECK-NEXT: %[[TMP3:.*]] = affine.linearize_index disjoint [%[[ARG5]], %[[ARG6]]] by (1024, 1)
-// CHECK-NEXT: affine.load %[[ARG0]][%[[TMP2]], %[[TMP3]]] : memref<1024x1024xf32>
-
-// -----
-
-// CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape_with_constant_access_index
+// CHECK-LABEL: fold_static_stride_subview_with_memref_expand_shape_with_constant_access_index
// CHECK-SAME: (%[[ARG0:.*]]: memref<1024x1024xf32>, %[[ARG1:.*]]: memref<1xf32>, %[[ARG2:.*]]: index)
-func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_with_constant_access_index(%arg0: memref<1024x1024xf32>, %arg1: memref<1xf32>, %arg2: index) -> f32 {
+func.func @fold_static_stride_subview_with_memref_expand_shape_with_constant_access_index(%arg0: memref<1024x1024xf32>, %arg1: memref<1xf32>, %arg2: index) -> f32 {
%0 = memref.expand_shape %arg0 [[0, 1], [2, 3]] output_shape [1, 1024, 1024, 1] : memref<1024x1024xf32> into memref<1x1024x1024x1xf32>
%cst = arith.constant 0 : index
affine.for %arg3 = 0 to 1 {
@@ -625,22 +450,6 @@ func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_with_c
// -----
-// CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_collapse_shape_with_0d_result
-// CHECK-SAME: (%[[ARG0:.*]]: memref<1xf32>, %[[ARG1:.*]]: memref<1xf32>)
-func.func @fold_static_stride_subview_with_affine_load_store_collapse_shape_with_0d_result(%arg0: memref<1xf32>, %arg1: memref<1xf32>) -> memref<1xf32> {
- %0 = memref.collapse_shape %arg0 [] : memref<1xf32> into memref<f32>
- affine.for %arg2 = 0 to 3 {
- %1 = affine.load %0[] : memref<f32>
- affine.store %1, %arg1[0] : memref<1xf32>
- }
- return %arg1 : memref<1xf32>
-}
-// CHECK-NEXT: %[[ZERO:.*]] = arith.constant 0 : index
-// CHECK-NEXT: affine.for %{{.*}} = 0 to 3 {
-// CHECK-NEXT: affine.load %[[ARG0]][%[[ZERO]]] : memref<1xf32>
-
-// -----
-
// CHECK: #[[$map:.*]] = affine_map<()[s0] -> (s0 + 2)>
// CHECK-LABEL: func @subview_of_subview(
// CHECK-SAME: %[[m:.*]]: memref<8x1024xf32, 3>, %[[pos:.*]]: index
More information about the Mlir-commits mailing list