[Mlir-commits] [mlir] 003b28d - [mlir] Move affine's FoldMemRefAliasOps into its own pass (#172548)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jan 2 10:13:46 PST 2026


Author: Krzysztof Drewniak
Date: 2026-01-02T10:13:42-08:00
New Revision: 003b28d0310dcb89e9c537a5d431c1d7b0d492f3

URL: https://github.com/llvm/llvm-project/commit/003b28d0310dcb89e9c537a5d431c1d7b0d492f3
DIFF: https://github.com/llvm/llvm-project/commit/003b28d0310dcb89e9c537a5d431c1d7b0d492f3.diff

LOG: [mlir] Move affine's FoldMemRefAliasOps into its own pass (#172548)

I'm planning to introduce an interface that'll allow FoldMemRefAliasOps
to not know about dialects like NVVM or GPU. To do this, however, I need
to get the `affine` ops (which need special handling in order to handle
their implicit affine maps) into a separate pass, analogously to how
`amdgpu` ops have these patterns under their dialect and not under
`memref`.

This commit also changes the expand/collapse_shape index resolvers to
return `void`, since they never actually failed and to make it clearer
that they modify IR.

(Note: An LLM did the initial refactoring and test movement, I've
reviewed the results and edited them some.)

Added: 
    mlir/lib/Dialect/Affine/Transforms/FoldMemRefAliasOps.cpp
    mlir/test/Dialect/Affine/fold-memref-alias-ops.mlir

Modified: 
    mlir/include/mlir/Dialect/Affine/Passes.h
    mlir/include/mlir/Dialect/Affine/Passes.td
    mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
    mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
    mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp
    mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt
    mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
    mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
    mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Affine/Passes.h b/mlir/include/mlir/Dialect/Affine/Passes.h
index ec349ec48e33b..58596733a76c0 100644
--- a/mlir/include/mlir/Dialect/Affine/Passes.h
+++ b/mlir/include/mlir/Dialect/Affine/Passes.h
@@ -19,6 +19,7 @@
 #include <limits>
 
 namespace mlir {
+class RewritePatternSet;
 
 namespace func {
 class FuncOp;
@@ -126,6 +127,10 @@ std::unique_ptr<Pass> createAffineExpandIndexOpsPass();
 /// operations.
 std::unique_ptr<Pass> createAffineExpandIndexOpsAsAffinePass();
 
+/// Appends patterns for folding memref aliasing ops into affine load/store
+/// ops into `patterns`.
+void populateAffineFoldMemRefAliasOpPatterns(RewritePatternSet &patterns);
+
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Dialect/Affine/Passes.td b/mlir/include/mlir/Dialect/Affine/Passes.td
index 1ad5344121d4e..430edffc29038 100644
--- a/mlir/include/mlir/Dialect/Affine/Passes.td
+++ b/mlir/include/mlir/Dialect/Affine/Passes.td
@@ -440,4 +440,15 @@ def AffineExpandIndexOpsAsAffine : Pass<"affine-expand-index-ops-as-affine"> {
   let constructor = "mlir::affine::createAffineExpandIndexOpsAsAffinePass()";
 }
 
+def AffineFoldMemRefAliasOps : Pass<"affine-fold-memref-alias-ops"> {
+  let summary = "Fold memref alias ops into affine memory ops";
+  let description = [{
+    The pass folds memref.subview, memref.expand_shape, and memref.collapse_shape
+    operations into affine memory operations (currently only `affine.load` and
+    `affine.store`). This is similar to the `fold-memref-alias-ops` pass in the
+    `memref` dialect but adds handling specific to affine operations.
+  }];
+  let dependentDialects = ["memref::MemRefDialect"];
+}
+
 #endif // MLIR_DIALECT_AFFINE_PASSES

diff  --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
index c403386bd214a..d04ae101fe1dc 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
@@ -22,7 +22,7 @@ def FoldMemRefAliasOpsPass : Pass<"fold-memref-alias-ops"> {
     from/to the original memref.
   }];
   let dependentDialects = [
-      "affine::AffineDialect", "memref::MemRefDialect", "vector::VectorDialect"
+      "memref::MemRefDialect", "vector::VectorDialect"
   ];
 }
 
@@ -197,7 +197,7 @@ def ReifyResultShapesPass : Pass<"reify-result-shapes"> {
   let description = [{
     This pass reifies the shapes of a subset of `ReifyRankedShapedTypeOpInterface`
     ops with `tensor` results.
-    
+
     The pass currently only supports result shape type reification for:
       - tensor::PadOp
       - tensor::ConcatOp

diff  --git a/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h b/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
index dd3b3dea6ef26..5d2429bb476e6 100644
--- a/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
+++ b/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
@@ -118,7 +118,7 @@ MemrefValue skipViewLikeOps(MemrefValue source);
 
 /// Given the 'indices' of a load/store operation where the memref is a result
 /// of a expand_shape op, returns the indices w.r.t to the source memref of the
-/// expand_shape op. For example
+/// expand_shape op into `sourceIndices`. For example
 ///
 /// %0 = ... : memref<12x42xf32>
 /// %1 = memref.expand_shape %0 [[0, 1], [2]]
@@ -129,14 +129,19 @@ MemrefValue skipViewLikeOps(MemrefValue source);
 ///
 /// %2 = load %0[6 * i1 + i2, %i3] :
 ///          memref<12x42xf32>
-LogicalResult resolveSourceIndicesExpandShape(
-    Location loc, PatternRewriter &rewriter,
-    memref::ExpandShapeOp expandShapeOp, ValueRange indices,
-    SmallVectorImpl<Value> &sourceIndices, bool startsInbounds);
+///
+/// If `startsInbounds` is true, optimizations that rely on all indices being
+/// non-negative and less than the corresponding memref dimension may be
+/// performed.
+void resolveSourceIndicesExpandShape(Location loc, PatternRewriter &rewriter,
+                                     memref::ExpandShapeOp expandShapeOp,
+                                     ValueRange indices,
+                                     SmallVectorImpl<Value> &sourceIndices,
+                                     bool startsInbounds);
 
 /// Given the 'indices' of a load/store operation where the memref is a result
 /// of a collapse_shape op, returns the indices w.r.t to the source memref of
-/// the collapse_shape op. For example
+/// the collapse_shape op, returning them in `sourceIndices`. For example
 ///
 /// %0 = ... : memref<2x6x42xf32>
 /// %1 = memref.collapse_shape %0 [[0, 1], [2]]
@@ -147,11 +152,10 @@ LogicalResult resolveSourceIndicesExpandShape(
 ///
 /// %2 = load %0[%i1 / 6, %i1 % 6, %i2] :
 ///          memref<2x6x42xf32>
-LogicalResult
-resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter,
-                                  memref::CollapseShapeOp collapseShapeOp,
-                                  ValueRange indices,
-                                  SmallVectorImpl<Value> &sourceIndices);
+void resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter,
+                                       memref::CollapseShapeOp collapseShapeOp,
+                                       ValueRange indices,
+                                       SmallVectorImpl<Value> &sourceIndices);
 
 } // namespace memref
 } // namespace mlir

diff  --git a/mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp
index d54751098410b..fb2b096df9c3d 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp
@@ -46,21 +46,15 @@ static LogicalResult foldMemrefViewOp(PatternRewriter &rewriter, Location loc,
         return success();
       })
       .Case<memref::ExpandShapeOp>([&](memref::ExpandShapeOp expandShapeOp) {
-        if (failed(mlir::memref::resolveSourceIndicesExpandShape(
-                loc, rewriter, expandShapeOp, indices, resolvedIndices,
-                false))) {
-          return failure();
-        }
+        mlir::memref::resolveSourceIndicesExpandShape(
+            loc, rewriter, expandShapeOp, indices, resolvedIndices, false);
         memrefBase = expandShapeOp.getViewSource();
         return success();
       })
       .Case<memref::CollapseShapeOp>(
           [&](memref::CollapseShapeOp collapseShapeOp) {
-            if (failed(mlir::memref::resolveSourceIndicesCollapseShape(
-                    loc, rewriter, collapseShapeOp, indices,
-                    resolvedIndices))) {
-              return failure();
-            }
+            mlir::memref::resolveSourceIndicesCollapseShape(
+                loc, rewriter, collapseShapeOp, indices, resolvedIndices);
             memrefBase = collapseShapeOp.getViewSource();
             return success();
           })

diff  --git a/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt
index c792200f4a49a..7bce124817032 100644
--- a/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt
@@ -7,6 +7,7 @@ add_mlir_dialect_library(MLIRAffineTransforms
   AffineParallelize.cpp
   AffineScalarReplacement.cpp
   DecomposeAffineOps.cpp
+  FoldMemRefAliasOps.cpp
   LoopCoalescing.cpp
   LoopFusion.cpp
   LoopTiling.cpp
@@ -34,6 +35,7 @@ add_mlir_dialect_library(MLIRAffineTransforms
   MLIRArithDialect
   MLIRIR
   MLIRMemRefDialect
+  MLIRMemRefUtils
   MLIRPass
   MLIRSCFUtils
   MLIRSideEffectInterfaces

diff  --git a/mlir/lib/Dialect/Affine/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/Affine/Transforms/FoldMemRefAliasOps.cpp
new file mode 100644
index 0000000000000..febdb94970431
--- /dev/null
+++ b/mlir/lib/Dialect/Affine/Transforms/FoldMemRefAliasOps.cpp
@@ -0,0 +1,253 @@
+//===- FoldMemRefAliasOps.cpp - Fold memref alias ops for affine ops ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass contains affine-specific versions of the folding patterns for
+// memref.expand_shape, memref.collapse_shape, and memref.subview, since
+// those all need affine-specific handling that won't fit a general interface.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/Passes.h"
+#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+namespace affine {
+#define GEN_PASS_DEF_AFFINEFOLDMEMREFALIASOPS
+#include "mlir/Dialect/Affine/Passes.h.inc"
+} // namespace affine
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::affine;
+
+//===----------------------------------------------------------------------===//
+// Utility functions
+//===----------------------------------------------------------------------===//
+
+/// Given an AffineMap and a list of indices, apply the map to get the
+/// underlying indices (expanding the affine map).
+static void expandToUnderlyingIndices(AffineMap affineMap, ValueRange indices,
+                                      Location loc, PatternRewriter &rewriter,
+                                      SmallVectorImpl<Value> &result) {
+  SmallVector<OpFoldResult> indicesOfr(
+      llvm::map_to_vector(indices, [](Value v) -> OpFoldResult { return v; }));
+  for (unsigned i : llvm::seq(0u, affineMap.getNumResults())) {
+    OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
+        rewriter, loc, affineMap.getSubMap({i}), indicesOfr);
+    result.push_back(getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
+  }
+}
+
+//===----------------------------------------------------------------------===//
+// Patterns
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+struct AffineLoadOpOfSubViewOpFolder final : OpRewritePattern<AffineLoadOp> {
+  using Base::Base;
+
+  LogicalResult matchAndRewrite(AffineLoadOp loadOp,
+                                PatternRewriter &rewriter) const override {
+    auto subViewOp = loadOp.getMemref().getDefiningOp<memref::SubViewOp>();
+
+    if (!subViewOp)
+      return rewriter.notifyMatchFailure(loadOp, "not a subview producer");
+
+    SmallVector<Value> indices;
+    expandToUnderlyingIndices(loadOp.getAffineMap(), loadOp.getIndices(),
+                              loadOp.getLoc(), rewriter, indices);
+
+    SmallVector<Value> sourceIndices;
+    affine::resolveIndicesIntoOpWithOffsetsAndStrides(
+        rewriter, loadOp.getLoc(), subViewOp.getMixedOffsets(),
+        subViewOp.getMixedStrides(), subViewOp.getDroppedDims(), indices,
+        sourceIndices);
+
+    rewriter.replaceOpWithNewOp<AffineLoadOp>(loadOp, subViewOp.getSource(),
+                                              sourceIndices);
+    return success();
+  }
+};
+
+struct AffineLoadOpOfExpandShapeOpFolder final
+    : OpRewritePattern<AffineLoadOp> {
+  using Base::Base;
+
+  LogicalResult matchAndRewrite(AffineLoadOp loadOp,
+                                PatternRewriter &rewriter) const override {
+    auto expandShapeOp =
+        loadOp.getMemref().getDefiningOp<memref::ExpandShapeOp>();
+
+    if (!expandShapeOp)
+      return failure();
+
+    SmallVector<Value> indices;
+    expandToUnderlyingIndices(loadOp.getAffineMap(), loadOp.getIndices(),
+                              loadOp.getLoc(), rewriter, indices);
+
+    SmallVector<Value> sourceIndices;
+    // affine.load guarantees that indexes start inbounds, which impacts if our
+    // linearization is `disjoint`.
+    memref::resolveSourceIndicesExpandShape(
+        loadOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices,
+        /*startsInbounds=*/true);
+
+    rewriter.replaceOpWithNewOp<AffineLoadOp>(
+        loadOp, expandShapeOp.getViewSource(), sourceIndices);
+    return success();
+  }
+};
+
+struct AffineLoadOpOfCollapseShapeOpFolder final
+    : OpRewritePattern<AffineLoadOp> {
+  using Base::Base;
+
+  LogicalResult matchAndRewrite(AffineLoadOp loadOp,
+                                PatternRewriter &rewriter) const override {
+    auto collapseShapeOp =
+        loadOp.getMemref().getDefiningOp<memref::CollapseShapeOp>();
+
+    if (!collapseShapeOp)
+      return failure();
+
+    SmallVector<Value> indices;
+    expandToUnderlyingIndices(loadOp.getAffineMap(), loadOp.getIndices(),
+                              loadOp.getLoc(), rewriter, indices);
+
+    SmallVector<Value> sourceIndices;
+    memref::resolveSourceIndicesCollapseShape(
+        loadOp.getLoc(), rewriter, collapseShapeOp, indices, sourceIndices);
+
+    rewriter.replaceOpWithNewOp<AffineLoadOp>(
+        loadOp, collapseShapeOp.getViewSource(), sourceIndices);
+    return success();
+  }
+};
+
+struct AffineStoreOpOfSubViewOpFolder final : OpRewritePattern<AffineStoreOp> {
+  using Base::Base;
+
+  LogicalResult matchAndRewrite(AffineStoreOp storeOp,
+                                PatternRewriter &rewriter) const override {
+    auto subViewOp = storeOp.getMemref().getDefiningOp<memref::SubViewOp>();
+
+    if (!subViewOp)
+      return rewriter.notifyMatchFailure(storeOp, "not a subview producer");
+
+    // For affine ops, we need to apply the map to get the "actual" indices.
+    SmallVector<Value> indices;
+    expandToUnderlyingIndices(storeOp.getAffineMap(), storeOp.getIndices(),
+                              storeOp.getLoc(), rewriter, indices);
+
+    SmallVector<Value> sourceIndices;
+    affine::resolveIndicesIntoOpWithOffsetsAndStrides(
+        rewriter, storeOp.getLoc(), subViewOp.getMixedOffsets(),
+        subViewOp.getMixedStrides(), subViewOp.getDroppedDims(), indices,
+        sourceIndices);
+
+    rewriter.replaceOpWithNewOp<AffineStoreOp>(
+        storeOp, storeOp.getValue(), subViewOp.getSource(), sourceIndices);
+    return success();
+  }
+};
+
+struct AffineStoreOpOfExpandShapeOpFolder final
+    : OpRewritePattern<AffineStoreOp> {
+  using Base::Base;
+
+  LogicalResult matchAndRewrite(AffineStoreOp storeOp,
+                                PatternRewriter &rewriter) const override {
+    auto expandShapeOp =
+        storeOp.getMemref().getDefiningOp<memref::ExpandShapeOp>();
+
+    if (!expandShapeOp)
+      return failure();
+
+    SmallVector<Value> indices;
+    expandToUnderlyingIndices(storeOp.getAffineMap(), storeOp.getIndices(),
+                              storeOp.getLoc(), rewriter, indices);
+
+    SmallVector<Value> sourceIndices;
+    // affine.store guarantees that indexes start inbounds, which impacts if our
+    // linearization is `disjoint`.
+    memref::resolveSourceIndicesExpandShape(
+        storeOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices,
+        /*startsInbounds=*/true);
+
+    rewriter.replaceOpWithNewOp<AffineStoreOp>(
+        storeOp, storeOp.getValueToStore(), expandShapeOp.getViewSource(),
+        sourceIndices);
+    return success();
+  }
+};
+
+struct AffineStoreOpOfCollapseShapeOpFolder final
+    : OpRewritePattern<AffineStoreOp> {
+  using Base::Base;
+
+  LogicalResult matchAndRewrite(AffineStoreOp storeOp,
+                                PatternRewriter &rewriter) const override {
+    auto collapseShapeOp =
+        storeOp.getMemref().getDefiningOp<memref::CollapseShapeOp>();
+
+    if (!collapseShapeOp)
+      return failure();
+
+    // For affine ops, we need to apply the map to get the "actual" indices.
+    SmallVector<Value> indices;
+    expandToUnderlyingIndices(storeOp.getAffineMap(), storeOp.getIndices(),
+                              storeOp.getLoc(), rewriter, indices);
+
+    SmallVector<Value> sourceIndices;
+    memref::resolveSourceIndicesCollapseShape(
+        storeOp.getLoc(), rewriter, collapseShapeOp, indices, sourceIndices);
+
+    rewriter.replaceOpWithNewOp<AffineStoreOp>(
+        storeOp, storeOp.getValueToStore(), collapseShapeOp.getViewSource(),
+        sourceIndices);
+    return success();
+  }
+};
+
+} // namespace
+
+void affine::populateAffineFoldMemRefAliasOpPatterns(
+    RewritePatternSet &patterns) {
+  patterns
+      .add<AffineLoadOpOfSubViewOpFolder, AffineLoadOpOfExpandShapeOpFolder,
+           AffineLoadOpOfCollapseShapeOpFolder, AffineStoreOpOfSubViewOpFolder,
+           AffineStoreOpOfExpandShapeOpFolder,
+           AffineStoreOpOfCollapseShapeOpFolder>(patterns.getContext());
+}
+
+//===----------------------------------------------------------------------===//
+// Pass registration
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+struct AffineFoldMemRefAliasOpsPass final
+    : public affine::impl::AffineFoldMemRefAliasOpsBase<
+          AffineFoldMemRefAliasOpsPass> {
+  void runOnOperation() override;
+};
+
+} // namespace
+
+void AffineFoldMemRefAliasOpsPass::runOnOperation() {
+  RewritePatternSet patterns(&getContext());
+  affine::populateAffineFoldMemRefAliasOpPatterns(patterns);
+  (void)applyPatternsGreedily(getOperation(), std::move(patterns));
+}

diff  --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
index 3667fdb2bb728..3cacb7e29263b 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
@@ -11,7 +11,6 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
@@ -198,22 +197,6 @@ class NVGPUAsyncCopyOpSubViewOpFolder final
 };
 } // namespace
 
-static SmallVector<Value>
-calculateExpandedAccessIndices(AffineMap affineMap,
-                               const SmallVector<Value> &indices, Location loc,
-                               PatternRewriter &rewriter) {
-  SmallVector<OpFoldResult> indicesOfr(llvm::to_vector(
-      llvm::map_range(indices, [](Value v) -> OpFoldResult { return v; })));
-  SmallVector<Value> expandedIndices;
-  for (unsigned i = 0, e = affineMap.getNumResults(); i < e; i++) {
-    OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
-        rewriter, loc, affineMap.getSubMap({i}), indicesOfr);
-    expandedIndices.push_back(
-        getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
-  }
-  return expandedIndices;
-}
-
 template <typename XferOp>
 static LogicalResult
 preconditionsFoldSubViewOpImpl(RewriterBase &rewriter, XferOp xferOp,
@@ -262,28 +245,13 @@ LogicalResult LoadOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
   if (failed(preconditionResult))
     return preconditionResult;
 
-  SmallVector<Value> indices(loadOp.getIndices().begin(),
-                             loadOp.getIndices().end());
-  // For affine ops, we need to apply the map to get the operands to get the
-  // "actual" indices.
-  if (auto affineLoadOp =
-          dyn_cast<affine::AffineLoadOp>(loadOp.getOperation())) {
-    AffineMap affineMap = affineLoadOp.getAffineMap();
-    auto expandedIndices = calculateExpandedAccessIndices(
-        affineMap, indices, loadOp.getLoc(), rewriter);
-    indices.assign(expandedIndices.begin(), expandedIndices.end());
-  }
   SmallVector<Value> sourceIndices;
   affine::resolveIndicesIntoOpWithOffsetsAndStrides(
       rewriter, loadOp.getLoc(), subViewOp.getMixedOffsets(),
-      subViewOp.getMixedStrides(), subViewOp.getDroppedDims(), indices,
-      sourceIndices);
+      subViewOp.getMixedStrides(), subViewOp.getDroppedDims(),
+      loadOp.getIndices(), sourceIndices);
 
   llvm::TypeSwitch<Operation *, void>(loadOp)
-      .Case([&](affine::AffineLoadOp op) {
-        rewriter.replaceOpWithNewOp<affine::AffineLoadOp>(
-            loadOp, subViewOp.getSource(), sourceIndices);
-      })
       .Case([&](memref::LoadOp op) {
         rewriter.replaceOpWithNewOp<memref::LoadOp>(
             loadOp, subViewOp.getSource(), sourceIndices, op.getNontemporal());
@@ -328,32 +296,14 @@ LogicalResult LoadOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
   if (!expandShapeOp)
     return failure();
 
-  SmallVector<Value> indices(loadOp.getIndices().begin(),
-                             loadOp.getIndices().end());
-  // For affine ops, we need to apply the map to get the operands to get the
-  // "actual" indices.
-  if (auto affineLoadOp =
-          dyn_cast<affine::AffineLoadOp>(loadOp.getOperation())) {
-    AffineMap affineMap = affineLoadOp.getAffineMap();
-    auto expandedIndices = calculateExpandedAccessIndices(
-        affineMap, indices, loadOp.getLoc(), rewriter);
-    indices.assign(expandedIndices.begin(), expandedIndices.end());
-  }
   SmallVector<Value> sourceIndices;
-  // memref.load and affine.load guarantee that indexes start inbounds
-  // while the vector operations don't. This impacts if our linearization
-  // is `disjoint`
-  if (failed(resolveSourceIndicesExpandShape(
-          loadOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices,
-          isa<affine::AffineLoadOp, memref::LoadOp>(loadOp.getOperation()))))
-    return failure();
+  // memref.load guarantees that indexes start inbounds while the vector
+  // operations don't. This impacts if our linearization is `disjoint`
+  resolveSourceIndicesExpandShape(loadOp.getLoc(), rewriter, expandShapeOp,
+                                  loadOp.getIndices(), sourceIndices,
+                                  isa<memref::LoadOp>(loadOp.getOperation()));
 
   return llvm::TypeSwitch<Operation *, LogicalResult>(loadOp)
-      .Case([&](affine::AffineLoadOp op) {
-        rewriter.replaceOpWithNewOp<affine::AffineLoadOp>(
-            loadOp, expandShapeOp.getViewSource(), sourceIndices);
-        return success();
-      })
       .Case([&](memref::LoadOp op) {
         rewriter.replaceOpWithNewOp<memref::LoadOp>(
             loadOp, expandShapeOp.getViewSource(), sourceIndices,
@@ -407,26 +357,10 @@ LogicalResult LoadOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
   if (!collapseShapeOp)
     return failure();
 
-  SmallVector<Value> indices(loadOp.getIndices().begin(),
-                             loadOp.getIndices().end());
-  // For affine ops, we need to apply the map to get the operands to get the
-  // "actual" indices.
-  if (auto affineLoadOp =
-          dyn_cast<affine::AffineLoadOp>(loadOp.getOperation())) {
-    AffineMap affineMap = affineLoadOp.getAffineMap();
-    auto expandedIndices = calculateExpandedAccessIndices(
-        affineMap, indices, loadOp.getLoc(), rewriter);
-    indices.assign(expandedIndices.begin(), expandedIndices.end());
-  }
   SmallVector<Value> sourceIndices;
-  if (failed(resolveSourceIndicesCollapseShape(
-          loadOp.getLoc(), rewriter, collapseShapeOp, indices, sourceIndices)))
-    return failure();
+  resolveSourceIndicesCollapseShape(loadOp.getLoc(), rewriter, collapseShapeOp,
+                                    loadOp.getIndices(), sourceIndices);
   llvm::TypeSwitch<Operation *, void>(loadOp)
-      .Case([&](affine::AffineLoadOp op) {
-        rewriter.replaceOpWithNewOp<affine::AffineLoadOp>(
-            loadOp, collapseShapeOp.getViewSource(), sourceIndices);
-      })
       .Case([&](memref::LoadOp op) {
         rewriter.replaceOpWithNewOp<memref::LoadOp>(
             loadOp, collapseShapeOp.getViewSource(), sourceIndices,
@@ -460,28 +394,13 @@ LogicalResult StoreOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
   if (failed(preconditionResult))
     return preconditionResult;
 
-  SmallVector<Value> indices(storeOp.getIndices().begin(),
-                             storeOp.getIndices().end());
-  // For affine ops, we need to apply the map to get the operands to get the
-  // "actual" indices.
-  if (auto affineStoreOp =
-          dyn_cast<affine::AffineStoreOp>(storeOp.getOperation())) {
-    AffineMap affineMap = affineStoreOp.getAffineMap();
-    auto expandedIndices = calculateExpandedAccessIndices(
-        affineMap, indices, storeOp.getLoc(), rewriter);
-    indices.assign(expandedIndices.begin(), expandedIndices.end());
-  }
   SmallVector<Value> sourceIndices;
   affine::resolveIndicesIntoOpWithOffsetsAndStrides(
       rewriter, storeOp.getLoc(), subViewOp.getMixedOffsets(),
-      subViewOp.getMixedStrides(), subViewOp.getDroppedDims(), indices,
-      sourceIndices);
+      subViewOp.getMixedStrides(), subViewOp.getDroppedDims(),
+      storeOp.getIndices(), sourceIndices);
 
   llvm::TypeSwitch<Operation *, void>(storeOp)
-      .Case([&](affine::AffineStoreOp op) {
-        rewriter.replaceOpWithNewOp<affine::AffineStoreOp>(
-            op, op.getValue(), subViewOp.getSource(), sourceIndices);
-      })
       .Case([&](memref::StoreOp op) {
         rewriter.replaceOpWithNewOp<memref::StoreOp>(
             op, op.getValue(), subViewOp.getSource(), sourceIndices,
@@ -522,31 +441,13 @@ LogicalResult StoreOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
   if (!expandShapeOp)
     return failure();
 
-  SmallVector<Value> indices(storeOp.getIndices().begin(),
-                             storeOp.getIndices().end());
-  // For affine ops, we need to apply the map to get the operands to get the
-  // "actual" indices.
-  if (auto affineStoreOp =
-          dyn_cast<affine::AffineStoreOp>(storeOp.getOperation())) {
-    AffineMap affineMap = affineStoreOp.getAffineMap();
-    auto expandedIndices = calculateExpandedAccessIndices(
-        affineMap, indices, storeOp.getLoc(), rewriter);
-    indices.assign(expandedIndices.begin(), expandedIndices.end());
-  }
   SmallVector<Value> sourceIndices;
-  // memref.store and affine.store guarantee that indexes start inbounds
-  // while the vector operations don't. This impacts if our linearization
-  // is `disjoint`
-  if (failed(resolveSourceIndicesExpandShape(
-          storeOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices,
-          isa<affine::AffineStoreOp, memref::StoreOp>(storeOp.getOperation()))))
-    return failure();
+  // memref.store guarantees that indexes start inbounds while the vector
+  // operations don't. This impacts if our linearization is `disjoint`
+  resolveSourceIndicesExpandShape(storeOp.getLoc(), rewriter, expandShapeOp,
+                                  storeOp.getIndices(), sourceIndices,
+                                  isa<memref::StoreOp>(storeOp.getOperation()));
   llvm::TypeSwitch<Operation *, void>(storeOp)
-      .Case([&](affine::AffineStoreOp op) {
-        rewriter.replaceOpWithNewOp<affine::AffineStoreOp>(
-            storeOp, op.getValueToStore(), expandShapeOp.getViewSource(),
-            sourceIndices);
-      })
       .Case([&](memref::StoreOp op) {
         rewriter.replaceOpWithNewOp<memref::StoreOp>(
             storeOp, op.getValueToStore(), expandShapeOp.getViewSource(),
@@ -575,27 +476,10 @@ LogicalResult StoreOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
   if (!collapseShapeOp)
     return failure();
 
-  SmallVector<Value> indices(storeOp.getIndices().begin(),
-                             storeOp.getIndices().end());
-  // For affine ops, we need to apply the map to get the operands to get the
-  // "actual" indices.
-  if (auto affineStoreOp =
-          dyn_cast<affine::AffineStoreOp>(storeOp.getOperation())) {
-    AffineMap affineMap = affineStoreOp.getAffineMap();
-    auto expandedIndices = calculateExpandedAccessIndices(
-        affineMap, indices, storeOp.getLoc(), rewriter);
-    indices.assign(expandedIndices.begin(), expandedIndices.end());
-  }
   SmallVector<Value> sourceIndices;
-  if (failed(resolveSourceIndicesCollapseShape(
-          storeOp.getLoc(), rewriter, collapseShapeOp, indices, sourceIndices)))
-    return failure();
+  resolveSourceIndicesCollapseShape(storeOp.getLoc(), rewriter, collapseShapeOp,
+                                    storeOp.getIndices(), sourceIndices);
   llvm::TypeSwitch<Operation *, void>(storeOp)
-      .Case([&](affine::AffineStoreOp op) {
-        rewriter.replaceOpWithNewOp<affine::AffineStoreOp>(
-            storeOp, op.getValueToStore(), collapseShapeOp.getViewSource(),
-            sourceIndices);
-      })
       .Case([&](memref::StoreOp op) {
         rewriter.replaceOpWithNewOp<memref::StoreOp>(
             storeOp, op.getValueToStore(), collapseShapeOp.getViewSource(),
@@ -630,29 +514,27 @@ LogicalResult NVGPUAsyncCopyOpSubViewOpFolder::matchAndRewrite(
                                                "source or destination");
 
   // If the source is a subview, we need to resolve the indices.
-  SmallVector<Value> srcindices(copyOp.getSrcIndices().begin(),
-                                copyOp.getSrcIndices().end());
-  SmallVector<Value> foldedSrcIndices(srcindices);
+  SmallVector<Value> foldedSrcIndices(copyOp.getSrcIndices().begin(),
+                                      copyOp.getSrcIndices().end());
 
   if (srcSubViewOp) {
     LLVM_DEBUG(DBGS() << "srcSubViewOp : " << srcSubViewOp << "\n");
     affine::resolveIndicesIntoOpWithOffsetsAndStrides(
         rewriter, copyOp.getLoc(), srcSubViewOp.getMixedOffsets(),
         srcSubViewOp.getMixedStrides(), srcSubViewOp.getDroppedDims(),
-        srcindices, foldedSrcIndices);
+        copyOp.getSrcIndices(), foldedSrcIndices);
   }
 
   // If the destination is a subview, we need to resolve the indices.
-  SmallVector<Value> dstindices(copyOp.getDstIndices().begin(),
-                                copyOp.getDstIndices().end());
-  SmallVector<Value> foldedDstIndices(dstindices);
+  SmallVector<Value> foldedDstIndices(copyOp.getDstIndices().begin(),
+                                      copyOp.getDstIndices().end());
 
   if (dstSubViewOp) {
     LLVM_DEBUG(DBGS() << "dstSubViewOp : " << dstSubViewOp << "\n");
     affine::resolveIndicesIntoOpWithOffsetsAndStrides(
         rewriter, copyOp.getLoc(), dstSubViewOp.getMixedOffsets(),
         dstSubViewOp.getMixedStrides(), dstSubViewOp.getDroppedDims(),
-        dstindices, foldedDstIndices);
+        copyOp.getDstIndices(), foldedDstIndices);
   }
 
   // Replace the copy op with a new copy op that uses the source and destination
@@ -669,33 +551,27 @@ LogicalResult NVGPUAsyncCopyOpSubViewOpFolder::matchAndRewrite(
 }
 
 void memref::populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns) {
-  patterns.add<LoadOpOfSubViewOpFolder<affine::AffineLoadOp>,
-               LoadOpOfSubViewOpFolder<memref::LoadOp>,
+  patterns.add<LoadOpOfSubViewOpFolder<memref::LoadOp>,
                LoadOpOfSubViewOpFolder<nvgpu::LdMatrixOp>,
                LoadOpOfSubViewOpFolder<vector::LoadOp>,
                LoadOpOfSubViewOpFolder<vector::MaskedLoadOp>,
                LoadOpOfSubViewOpFolder<vector::TransferReadOp>,
                LoadOpOfSubViewOpFolder<gpu::SubgroupMmaLoadMatrixOp>,
-               StoreOpOfSubViewOpFolder<affine::AffineStoreOp>,
                StoreOpOfSubViewOpFolder<memref::StoreOp>,
                StoreOpOfSubViewOpFolder<vector::TransferWriteOp>,
                StoreOpOfSubViewOpFolder<vector::StoreOp>,
                StoreOpOfSubViewOpFolder<vector::MaskedStoreOp>,
                StoreOpOfSubViewOpFolder<gpu::SubgroupMmaStoreMatrixOp>,
-               LoadOpOfExpandShapeOpFolder<affine::AffineLoadOp>,
                LoadOpOfExpandShapeOpFolder<memref::LoadOp>,
                LoadOpOfExpandShapeOpFolder<vector::LoadOp>,
                LoadOpOfExpandShapeOpFolder<vector::MaskedLoadOp>,
                LoadOpOfExpandShapeOpFolder<vector::TransferReadOp>,
-               StoreOpOfExpandShapeOpFolder<affine::AffineStoreOp>,
                StoreOpOfExpandShapeOpFolder<memref::StoreOp>,
                StoreOpOfExpandShapeOpFolder<vector::StoreOp>,
                StoreOpOfExpandShapeOpFolder<vector::MaskedStoreOp>,
-               LoadOpOfCollapseShapeOpFolder<affine::AffineLoadOp>,
                LoadOpOfCollapseShapeOpFolder<memref::LoadOp>,
                LoadOpOfCollapseShapeOpFolder<vector::LoadOp>,
                LoadOpOfCollapseShapeOpFolder<vector::MaskedLoadOp>,
-               StoreOpOfCollapseShapeOpFolder<affine::AffineStoreOp>,
                StoreOpOfCollapseShapeOpFolder<memref::StoreOp>,
                StoreOpOfCollapseShapeOpFolder<vector::StoreOp>,
                StoreOpOfCollapseShapeOpFolder<vector::MaskedStoreOp>,

diff  --git a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
index e5486988947c6..2d341dce665e5 100644
--- a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
+++ b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
@@ -223,10 +223,11 @@ MemrefValue skipViewLikeOps(MemrefValue source) {
   return source;
 }
 
-LogicalResult resolveSourceIndicesExpandShape(
-    Location loc, PatternRewriter &rewriter,
-    memref::ExpandShapeOp expandShapeOp, ValueRange indices,
-    SmallVectorImpl<Value> &sourceIndices, bool startsInbounds) {
+void resolveSourceIndicesExpandShape(Location loc, PatternRewriter &rewriter,
+                                     memref::ExpandShapeOp expandShapeOp,
+                                     ValueRange indices,
+                                     SmallVectorImpl<Value> &sourceIndices,
+                                     bool startsInbounds) {
   SmallVector<OpFoldResult> destShape = expandShapeOp.getMixedOutputShape();
 
   // Traverse all reassociation groups to determine the appropriate indices
@@ -246,14 +247,12 @@ LogicalResult resolveSourceIndicesExpandShape(
         rewriter, loc, groupIndices, groupBasis, /*disjoint=*/startsInbounds);
     sourceIndices.push_back(collapsedIndex);
   }
-  return success();
 }
 
-LogicalResult
-resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter,
-                                  memref::CollapseShapeOp collapseShapeOp,
-                                  ValueRange indices,
-                                  SmallVectorImpl<Value> &sourceIndices) {
+void resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter,
+                                       memref::CollapseShapeOp collapseShapeOp,
+                                       ValueRange indices,
+                                       SmallVectorImpl<Value> &sourceIndices) {
   // Note: collapse_shape requires a strided memref, we can do this.
   auto metadata = memref::ExtractStridedMetadataOp::create(
       rewriter, loc, collapseShapeOp.getSrc());
@@ -285,7 +284,6 @@ resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter,
           getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
     }
   }
-  return success();
 }
 
 } // namespace memref

diff  --git a/mlir/test/Dialect/Affine/fold-memref-alias-ops.mlir b/mlir/test/Dialect/Affine/fold-memref-alias-ops.mlir
new file mode 100644
index 0000000000000..5e3e107531802
--- /dev/null
+++ b/mlir/test/Dialect/Affine/fold-memref-alias-ops.mlir
@@ -0,0 +1,191 @@
+// RUN: mlir-opt -affine-fold-memref-alias-ops -split-input-file %s | FileCheck %s
+
+// Tests for folding memref aliasing ops (expand/collapse_shape and subview) into
+// affine memory access operations, which require specific handling (and thus
+// their own pass) because of the embedded affine maps.
+
+// CHECK-LABEL: func @fold_static_stride_subview_with_affine_load_store
+func.func @fold_static_stride_subview_with_affine_load_store(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index) -> f32 {
+  %0 = memref.subview %arg0[%arg1, %arg2][4, 4][2, 3] : memref<12x32xf32> to memref<4x4xf32, strided<[64, 3], offset: ?>>
+  %1 = affine.load %0[%arg3, %arg4] : memref<4x4xf32, strided<[64, 3], offset: ?>>
+  // CHECK-NEXT: affine.apply
+  // CHECK-NEXT: affine.apply
+  // CHECK-NEXT: affine.load
+  affine.store %1, %0[%arg3, %arg4] : memref<4x4xf32, strided<[64, 3], offset: ?>>
+  // CHECK-NEXT: affine.apply
+  // CHECK-NEXT: affine.apply
+  // CHECK-NEXT: affine.store
+  // CHECK-NEXT: return
+  return %1 : f32
+}
+
+// -----
+
+// CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape
+// CHECK-SAME: (%[[ARG0:.*]]: memref<12x32xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index) -> f32 {
+func.func @fold_static_stride_subview_with_affine_load_store_expand_shape(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index) -> f32 {
+  %0 = memref.expand_shape %arg0 [[0, 1], [2]] output_shape [2, 6, 32] : memref<12x32xf32> into memref<2x6x32xf32>
+  %1 = affine.load %0[%arg1, %arg2, %arg3] : memref<2x6x32xf32>
+  return %1 : f32
+}
+// CHECK: %[[INDEX:.*]] = affine.linearize_index disjoint [%[[ARG1]], %[[ARG2]]] by (2, 6)
+// CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[INDEX]], %[[ARG3]]] : memref<12x32xf32>
+// CHECK-NEXT: return %[[RESULT]] : f32
+
+// -----
+
+// CHECK-LABEL: @fold_static_stride_subview_with_affine_load_store_collapse_shape
+// CHECK-SAME: (%[[ARG0:.*]]: memref<2x6x32xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
+func.func @fold_static_stride_subview_with_affine_load_store_collapse_shape(%arg0 : memref<2x6x32xf32>, %arg1 : index, %arg2 : index) -> f32 {
+  %0 = memref.collapse_shape %arg0 [[0, 1], [2]] : memref<2x6x32xf32> into memref<12x32xf32>
+  %1 = affine.load %0[%arg1, %arg2] : memref<12x32xf32>
+  return %1 : f32
+}
+// CHECK-NEXT: %[[MODIFIED_INDEXES:.*]]:2 = affine.delinearize_index %[[ARG1]] into (2, 6)
+// CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[MODIFIED_INDEXES]]#0, %[[MODIFIED_INDEXES]]#1, %[[ARG2]]] : memref<2x6x32xf32>
+// CHECK-NEXT: return %[[RESULT]] : f32
+
+// -----
+
+// CHECK-LABEL: @fold_dynamic_size_collapse_shape_with_affine_load
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?x6x32xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
+func.func @fold_dynamic_size_collapse_shape_with_affine_load(%arg0 : memref<?x6x32xf32>, %arg1 : index, %arg2 : index) -> f32 {
+  %0 = memref.collapse_shape %arg0 [[0, 1], [2]] : memref<?x6x32xf32> into memref<?x32xf32>
+  %1 = affine.load %0[%arg1, %arg2] : memref<?x32xf32>
+  return %1 : f32
+}
+// CHECK-NEXT: %{{.*}}, %{{.*}}, %[[SIZES:.*]]:3, %{{.*}}:3 = memref.extract_strided_metadata %[[ARG0]]
+// CHECK-NEXT: %[[MODIFIED_INDEXES:.*]]:2 = affine.delinearize_index %[[ARG1]] into (%[[SIZES]]#0, 6)
+// CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[MODIFIED_INDEXES]]#0, %[[MODIFIED_INDEXES]]#1, %[[ARG2]]] : memref<?x6x32xf32>
+// CHECK-NEXT: return %[[RESULT]] : f32
+
+// -----
+
+// CHECK-LABEL: @fold_fully_dynamic_size_collapse_shape_with_affine_load
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
+func.func @fold_fully_dynamic_size_collapse_shape_with_affine_load(%arg0 : memref<?x?x?xf32>, %arg1 : index, %arg2 : index) -> f32 {
+  %0 = memref.collapse_shape %arg0 [[0, 1], [2]] : memref<?x?x?xf32> into memref<?x?xf32>
+  %1 = affine.load %0[%arg1, %arg2] : memref<?x?xf32>
+  return %1 : f32
+}
+// CHECK-NEXT: %{{.*}}, %{{.*}}, %[[SIZES:.*]]:3, %{{.*}}:3 = memref.extract_strided_metadata %[[ARG0]]
+// CHECK-NEXT: %[[MODIFIED_INDEXES:.*]]:2 = affine.delinearize_index %[[ARG1]] into (%[[SIZES]]#0, %[[SIZES]]#1)
+// CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[MODIFIED_INDEXES]]#0, %[[MODIFIED_INDEXES]]#1, %[[ARG2]]] : memref<?x?x?xf32>
+// CHECK-NEXT: return %[[RESULT]] : f32
+
+
+// -----
+
+// CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape_3d
+// CHECK-SAME: (%[[ARG0:.*]]: memref<12x32xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index) -> f32 {
+func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_3d(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4: index) -> f32 {
+  %0 = memref.expand_shape %arg0 [[0, 1, 2], [3]] output_shape [2, 2, 3, 32] : memref<12x32xf32> into memref<2x2x3x32xf32>
+  %1 = affine.load %0[%arg1, %arg2, %arg3, %arg4] : memref<2x2x3x32xf32>
+  return %1 : f32
+}
+// CHECK: %[[INDEX:.*]] = affine.linearize_index disjoint [%[[ARG1]], %[[ARG2]], %[[ARG3]]] by (2, 2, 3)
+// CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[INDEX]], %[[ARG4]]] : memref<12x32xf32>
+// CHECK-NEXT: return %[[RESULT]] : f32
+
+// -----
+
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0)[s0] -> (d0 + s0)>
+// CHECK-LABEL: fold_memref_alias_expand_shape_subview_load_store_dynamic_dim
+// CHECK-SAME: (%[[ARG0:.*]]: memref<2048x16xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index)
+func.func @fold_memref_alias_expand_shape_subview_load_store_dynamic_dim(%alloc: memref<2048x16xf32>, %c10: index, %c5: index, %c0: index, %sz0: index) {
+  %subview = memref.subview %alloc[%c5, 0] [%c10, 16] [1, 1] : memref<2048x16xf32> to memref<?x16xf32, strided<[16, 1], offset: ?>>
+  %expand_shape = memref.expand_shape %subview [[0], [1, 2, 3]] output_shape [%sz0, 1, 8, 2] : memref<?x16xf32, strided<[16, 1], offset: ?>> into memref<?x1x8x2xf32, strided<[16, 16, 2, 1], offset: ?>>
+  %dim = memref.dim %expand_shape, %c0 : memref<?x1x8x2xf32, strided<[16, 16, 2, 1], offset: ?>>
+
+  affine.for %arg6 = 0 to %dim step 64 {
+    affine.for %arg7 = 0 to 16 step 16 {
+      %dummy_load = affine.load %expand_shape[%arg6, 0, %arg7, %arg7] : memref<?x1x8x2xf32, strided<[16, 16, 2, 1], offset: ?>>
+      affine.store %dummy_load, %subview[%arg6, %arg7] : memref<?x16xf32, strided<[16, 1], offset: ?>>
+    }
+  }
+  return
+}
+// CHECK-NEXT:   %[[C0:.*]] = arith.constant 0
+// CHECK-NEXT:   memref.subview
+// CHECK-NEXT:   %[[EXPAND_SHAPE:.*]] = memref.expand_shape
+// CHECK-NEXT:   %[[DIM:.*]] = memref.dim %[[EXPAND_SHAPE]], %[[ARG3]] : memref<?x1x8x2xf32, strided<[16, 16, 2, 1], offset: ?>>
+// CHECK-NEXT:   affine.for %[[ARG5:.*]] = 0 to %[[DIM]] step 64 {
+// CHECK-NEXT:   affine.for %[[ARG6:.*]] = 0 to 16 step 16 {
+// CHECK-NEXT:   %[[VAL0:.*]] = affine.linearize_index disjoint [%[[C0]], %[[ARG6]], %[[ARG6]]] by (1, 8, 2)
+// CHECK-NEXT:   %[[VAL1:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])[%[[ARG2]]]
+// CHECK-NEXT:   %[[VAL2:.*]] = affine.load %[[ARG0]][%[[VAL1]], %[[VAL0]]] : memref<2048x16xf32>
+// CHECK-NEXT:   %[[VAL3:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])[%[[ARG2]]]
+// CHECK-NEXT:   affine.store %[[VAL2]], %[[ARG0]][%[[VAL3]], %[[ARG6]]] : memref<2048x16xf32>
+
+// -----
+
+// CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape_loops
+// CHECK-SAME: (%[[ARG0:.*]]: memref<1024x1024xf32>, %[[ARG1:.*]]: memref<1xf32>, %[[ARG2:.*]]: index)
+func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_loops(%arg0: memref<1024x1024xf32>, %arg1: memref<1xf32>, %arg2: index) -> f32 {
+  %0 = memref.expand_shape %arg0 [[0, 1], [2, 3]] output_shape [1, 1024, 1024, 1] : memref<1024x1024xf32> into memref<1x1024x1024x1xf32>
+  affine.for %arg3 = 0 to 1 {
+    affine.for %arg4 = 0 to 1024 {
+      affine.for %arg5 = 0 to 1020 {
+        affine.for %arg6 = 0 to 1 {
+          %1 = affine.load %0[%arg3, %arg4, %arg5, %arg6] : memref<1x1024x1024x1xf32>
+          affine.store %1, %arg1[%arg2] : memref<1xf32>
+        }
+      }
+    }
+  }
+  %2 = affine.load %arg1[%arg2] : memref<1xf32>
+  return %2 : f32
+}
+// CHECK-NEXT: affine.for %[[ARG3:.*]] = 0 to 1 {
+// CHECK-NEXT:  affine.for %[[ARG4:.*]] = 0 to 1024 {
+// CHECK-NEXT:   affine.for %[[ARG5:.*]] = 0 to 1020 {
+// CHECK-NEXT:    affine.for %[[ARG6:.*]] = 0 to 1 {
+// CHECK-NEXT:     %[[IDX1:.*]] = affine.linearize_index disjoint [%[[ARG3]], %[[ARG4]]] by (1, 1024)
+// CHECK-NEXT:     %[[IDX2:.*]] = affine.linearize_index disjoint [%[[ARG5]], %[[ARG6]]] by (1024, 1)
+// CHECK-NEXT:     affine.load %[[ARG0]][%[[IDX1]], %[[IDX2]]] : memref<1024x1024xf32>
+
+// -----
+
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
+// CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape_when_access_index_is_an_expression
+// CHECK-SAME: (%[[ARG0:.*]]: memref<1024x1024xf32>, %[[ARG1:.*]]: memref<1xf32>, %[[ARG2:.*]]: index)
+func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_when_access_index_is_an_expression(%arg0: memref<1024x1024xf32>, %arg1: memref<1xf32>, %arg2: index) -> f32 {
+  %0 = memref.expand_shape %arg0 [[0, 1], [2, 3]] output_shape [1, 1024, 1024, 1] : memref<1024x1024xf32> into memref<1x1024x1024x1xf32>
+  affine.for %arg3 = 0 to 1 {
+    affine.for %arg4 = 0 to 1024 {
+      affine.for %arg5 = 0 to 1020 {
+        affine.for %arg6 = 0 to 1 {
+          %1 = affine.load %0[%arg3, %arg4 + %arg3, %arg5, %arg6] : memref<1x1024x1024x1xf32>
+          affine.store %1, %arg1[%arg2] : memref<1xf32>
+        }
+      }
+    }
+  }
+  %2 = affine.load %arg1[%arg2] : memref<1xf32>
+  return %2 : f32
+}
+// CHECK-NEXT: affine.for %[[ARG3:.*]] = 0 to 1 {
+// CHECK-NEXT:  affine.for %[[ARG4:.*]] = 0 to 1024 {
+// CHECK-NEXT:   affine.for %[[ARG5:.*]] = 0 to 1020 {
+// CHECK-NEXT:    affine.for %[[ARG6:.*]] = 0 to 1 {
+// CHECK-NEXT:      %[[TMP1:.*]] = affine.apply #[[$MAP0]](%[[ARG3]], %[[ARG4]])
+// CHECK-NEXT:      %[[TMP2:.*]] = affine.linearize_index disjoint [%[[ARG3]], %[[TMP1]]] by (1, 1024)
+// CHECK-NEXT:      %[[TMP3:.*]] = affine.linearize_index disjoint [%[[ARG5]], %[[ARG6]]] by (1024, 1)
+// CHECK-NEXT:      affine.load %[[ARG0]][%[[TMP2]], %[[TMP3]]] : memref<1024x1024xf32>
+
+// -----
+
+// CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_collapse_shape_with_0d_result
+// CHECK-SAME: (%[[ARG0:.*]]: memref<1xf32>, %[[ARG1:.*]]: memref<1xf32>)
+func.func @fold_static_stride_subview_with_affine_load_store_collapse_shape_with_0d_result(%arg0: memref<1xf32>, %arg1: memref<1xf32>) -> memref<1xf32> {
+  %0 = memref.collapse_shape %arg0 [] : memref<1xf32> into memref<f32>
+  affine.for %arg2 = 0 to 3 {
+    %1 = affine.load %0[] : memref<f32>
+    affine.store %1, %arg1[0] : memref<1xf32>
+  }
+  return %arg1 : memref<1xf32>
+}
+// CHECK-NEXT: %[[ZERO:.*]] = arith.constant 0 : index
+// CHECK-NEXT: affine.for %{{.*}} = 0 to 3 {
+// CHECK-NEXT:   affine.load %[[ARG0]][%[[ZERO]]] : memref<1xf32>
+

diff  --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
index ca91b0141f593..93e5ba462584a 100644
--- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
+++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
@@ -387,95 +387,6 @@ func.func @fold_masked_vector_transfer_write_with_rank_reducing_subview(
 
 // -----
 
-//  Test with affine.load/store ops. We only do a basic test here since the
-//  logic is identical to that with memref.load/store ops. The same affine.apply
-//  ops would be generated.
-
-// CHECK-LABEL: func @fold_static_stride_subview_with_affine_load_store
-func.func @fold_static_stride_subview_with_affine_load_store(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index) -> f32 {
-  %0 = memref.subview %arg0[%arg1, %arg2][4, 4][2, 3] : memref<12x32xf32> to memref<4x4xf32, strided<[64, 3], offset: ?>>
-  %1 = affine.load %0[%arg3, %arg4] : memref<4x4xf32, strided<[64, 3], offset: ?>>
-  // CHECK-NEXT: affine.apply
-  // CHECK-NEXT: affine.apply
-  // CHECK-NEXT: affine.load
-  affine.store %1, %0[%arg3, %arg4] : memref<4x4xf32, strided<[64, 3], offset: ?>>
-  // CHECK-NEXT: affine.apply
-  // CHECK-NEXT: affine.apply
-  // CHECK-NEXT: affine.store
-  // CHECK-NEXT: return
-  return %1 : f32
-}
-
-// -----
-
-// CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape
-// CHECK-SAME: (%[[ARG0:.*]]: memref<12x32xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index) -> f32 {
-func.func @fold_static_stride_subview_with_affine_load_store_expand_shape(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index) -> f32 {
-  %0 = memref.expand_shape %arg0 [[0, 1], [2]] output_shape [2, 6, 32] : memref<12x32xf32> into memref<2x6x32xf32>
-  %1 = affine.load %0[%arg1, %arg2, %arg3] : memref<2x6x32xf32>
-  return %1 : f32
-}
-// CHECK: %[[INDEX:.*]] = affine.linearize_index disjoint [%[[ARG1]], %[[ARG2]]] by (2, 6)
-// CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[INDEX]], %[[ARG3]]] : memref<12x32xf32>
-// CHECK-NEXT: return %[[RESULT]] : f32
-
-// -----
-
-// CHECK-LABEL: @fold_static_stride_subview_with_affine_load_store_collapse_shape
-// CHECK-SAME: (%[[ARG0:.*]]: memref<2x6x32xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
-func.func @fold_static_stride_subview_with_affine_load_store_collapse_shape(%arg0 : memref<2x6x32xf32>, %arg1 : index, %arg2 : index) -> f32 {
-  %0 = memref.collapse_shape %arg0 [[0, 1], [2]] : memref<2x6x32xf32> into memref<12x32xf32>
-  %1 = affine.load %0[%arg1, %arg2] : memref<12x32xf32>
-  return %1 : f32
-}
-// CHECK-NEXT: %[[MODIFIED_INDEXES:.*]]:2 = affine.delinearize_index %[[ARG1]] into (2, 6)
-// CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[MODIFIED_INDEXES]]#0, %[[MODIFIED_INDEXES]]#1, %[[ARG2]]] : memref<2x6x32xf32>
-// CHECK-NEXT: return %[[RESULT]] : f32
-
-// -----
-
-// CHECK-LABEL: @fold_dynamic_size_collapse_shape_with_affine_load
-// CHECK-SAME: (%[[ARG0:.*]]: memref<?x6x32xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
-func.func @fold_dynamic_size_collapse_shape_with_affine_load(%arg0 : memref<?x6x32xf32>, %arg1 : index, %arg2 : index) -> f32 {
-  %0 = memref.collapse_shape %arg0 [[0, 1], [2]] : memref<?x6x32xf32> into memref<?x32xf32>
-  %1 = affine.load %0[%arg1, %arg2] : memref<?x32xf32>
-  return %1 : f32
-}
-// CHECK-NEXT: %{{.*}}, %{{.*}}, %[[SIZES:.*]]:3, %{{.*}}:3 = memref.extract_strided_metadata %[[ARG0]]
-// CHECK-NEXT: %[[MODIFIED_INDEXES:.*]]:2 = affine.delinearize_index %[[ARG1]] into (%[[SIZES]]#0, 6)
-// CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[MODIFIED_INDEXES]]#0, %[[MODIFIED_INDEXES]]#1, %[[ARG2]]] : memref<?x6x32xf32>
-// CHECK-NEXT: return %[[RESULT]] : f32
-
-// -----
-
-// CHECK-LABEL: @fold_fully_dynamic_size_collapse_shape_with_affine_load
-// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
-func.func @fold_fully_dynamic_size_collapse_shape_with_affine_load(%arg0 : memref<?x?x?xf32>, %arg1 : index, %arg2 : index) -> f32 {
-  %0 = memref.collapse_shape %arg0 [[0, 1], [2]] : memref<?x?x?xf32> into memref<?x?xf32>
-  %1 = affine.load %0[%arg1, %arg2] : memref<?x?xf32>
-  return %1 : f32
-}
-// CHECK-NEXT: %{{.*}}, %{{.*}}, %[[SIZES:.*]]:3, %{{.*}}:3 = memref.extract_strided_metadata %[[ARG0]]
-// CHECK-NEXT: %[[MODIFIED_INDEXES:.*]]:2 = affine.delinearize_index %[[ARG1]] into (%[[SIZES]]#0, %[[SIZES]]#1)
-// CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[MODIFIED_INDEXES]]#0, %[[MODIFIED_INDEXES]]#1, %[[ARG2]]] : memref<?x?x?xf32>
-// CHECK-NEXT: return %[[RESULT]] : f32
-
-
-// -----
-
-// CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape_3d
-// CHECK-SAME: (%[[ARG0:.*]]: memref<12x32xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index) -> f32 {
-func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_3d(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4: index) -> f32 {
-  %0 = memref.expand_shape %arg0 [[0, 1, 2], [3]] output_shape [2, 2, 3, 32] : memref<12x32xf32> into memref<2x2x3x32xf32>
-  %1 = affine.load %0[%arg1, %arg2, %arg3, %arg4] : memref<2x2x3x32xf32>
-  return %1 : f32
-}
-// CHECK: %[[INDEX:.*]] = affine.linearize_index disjoint [%[[ARG1]], %[[ARG2]], %[[ARG3]]] by (2, 2, 3)
-// CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[INDEX]], %[[ARG4]]] : memref<12x32xf32>
-// CHECK-NEXT: return %[[RESULT]] : f32
-
-// -----
-
 // CHECK-LABEL: fold_dynamic_subview_with_memref_load_expand_shape
 // CHECK-SAME: (%[[ARG0:.*]]: memref<16x?xf32, strided<[16, 1]>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index) -> f32
 func.func @fold_dynamic_subview_with_memref_load_expand_shape(%arg0 : memref<16x?xf32, strided<[16, 1]>>, %arg1 : index, %arg2 : index, %sz0: index) -> f32 {
@@ -510,95 +421,9 @@ func.func @fold_dynamic_subview_with_memref_store_expand_shape(%arg0 : memref<16
 
 // -----
 
-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0)[s0] -> (d0 + s0)>
-// CHECK-LABEL: fold_memref_alias_expand_shape_subview_load_store_dynamic_dim
-// CHECK-SAME: (%[[ARG0:.*]]: memref<2048x16xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index)
-func.func @fold_memref_alias_expand_shape_subview_load_store_dynamic_dim(%alloc: memref<2048x16xf32>, %c10: index, %c5: index, %c0: index, %sz0: index) {
-  %subview = memref.subview %alloc[%c5, 0] [%c10, 16] [1, 1] : memref<2048x16xf32> to memref<?x16xf32, strided<[16, 1], offset: ?>>
-  %expand_shape = memref.expand_shape %subview [[0], [1, 2, 3]] output_shape [%sz0, 1, 8, 2] : memref<?x16xf32, strided<[16, 1], offset: ?>> into memref<?x1x8x2xf32, strided<[16, 16, 2, 1], offset: ?>>
-  %dim = memref.dim %expand_shape, %c0 : memref<?x1x8x2xf32, strided<[16, 16, 2, 1], offset: ?>>
-
-  affine.for %arg6 = 0 to %dim step 64 {
-    affine.for %arg7 = 0 to 16 step 16 {
-      %dummy_load = affine.load %expand_shape[%arg6, 0, %arg7, %arg7] : memref<?x1x8x2xf32, strided<[16, 16, 2, 1], offset: ?>>
-      affine.store %dummy_load, %subview[%arg6, %arg7] : memref<?x16xf32, strided<[16, 1], offset: ?>>
-    }
-  }
-  return
-}
-// CHECK-NEXT:   %[[C0:.*]] = arith.constant 0
-// CHECK-NEXT:   memref.subview
-// CHECK-NEXT:   %[[EXPAND_SHAPE:.*]] = memref.expand_shape
-// CHECK-NEXT:   %[[DIM:.*]] = memref.dim %[[EXPAND_SHAPE]], %[[ARG3]] : memref<?x1x8x2xf32, strided<[16, 16, 2, 1], offset: ?>>
-// CHECK-NEXT:   affine.for %[[ARG5:.*]] = 0 to %[[DIM]] step 64 {
-// CHECK-NEXT:   affine.for %[[ARG6:.*]] = 0 to 16 step 16 {
-// CHECK-NEXT:   %[[VAL0:.*]] = affine.linearize_index disjoint [%[[C0]], %[[ARG6]], %[[ARG6]]] by (1, 8, 2)
-// CHECK-NEXT:   %[[VAL1:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])[%[[ARG2]]]
-// CHECK-NEXT:   %[[VAL2:.*]] = affine.load %[[ARG0]][%[[VAL1]], %[[VAL0]]] : memref<2048x16xf32>
-// CHECK-NEXT:   %[[VAL3:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])[%[[ARG2]]]
-// CHECK-NEXT:   affine.store %[[VAL2]], %[[ARG0]][%[[VAL3]], %[[ARG6]]] : memref<2048x16xf32>
-
-// -----
-
-// CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape
-// CHECK-SAME: (%[[ARG0:.*]]: memref<1024x1024xf32>, %[[ARG1:.*]]: memref<1xf32>, %[[ARG2:.*]]: index)
-func.func @fold_static_stride_subview_with_affine_load_store_expand_shape(%arg0: memref<1024x1024xf32>, %arg1: memref<1xf32>, %arg2: index) -> f32 {
-  %0 = memref.expand_shape %arg0 [[0, 1], [2, 3]] output_shape [1, 1024, 1024, 1] : memref<1024x1024xf32> into memref<1x1024x1024x1xf32>
-  affine.for %arg3 = 0 to 1 {
-    affine.for %arg4 = 0 to 1024 {
-      affine.for %arg5 = 0 to 1020 {
-        affine.for %arg6 = 0 to 1 {
-          %1 = affine.load %0[%arg3, %arg4, %arg5, %arg6] : memref<1x1024x1024x1xf32>
-          affine.store %1, %arg1[%arg2] : memref<1xf32>
-        }
-      }
-    }
-  }
-  %2 = affine.load %arg1[%arg2] : memref<1xf32>
-  return %2 : f32
-}
-// CHECK-NEXT: affine.for %[[ARG3:.*]] = 0 to 1 {
-// CHECK-NEXT:  affine.for %[[ARG4:.*]] = 0 to 1024 {
-// CHECK-NEXT:   affine.for %[[ARG5:.*]] = 0 to 1020 {
-// CHECK-NEXT:    affine.for %[[ARG6:.*]] = 0 to 1 {
-// CHECK-NEXT:     %[[IDX1:.*]] = affine.linearize_index disjoint [%[[ARG3]], %[[ARG4]]] by (1, 1024)
-// CHECK-NEXT:     %[[IDX2:.*]] = affine.linearize_index disjoint [%[[ARG5]], %[[ARG6]]] by (1024, 1)
-// CHECK-NEXT:     affine.load %[[ARG0]][%[[IDX1]], %[[IDX2]]] : memref<1024x1024xf32>
-
-// -----
-
-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
-// CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape_when_access_index_is_an_expression
-// CHECK-SAME: (%[[ARG0:.*]]: memref<1024x1024xf32>, %[[ARG1:.*]]: memref<1xf32>, %[[ARG2:.*]]: index)
-func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_when_access_index_is_an_expression(%arg0: memref<1024x1024xf32>, %arg1: memref<1xf32>, %arg2: index) -> f32 {
-  %0 = memref.expand_shape %arg0 [[0, 1], [2, 3]] output_shape [1, 1024, 1024, 1] : memref<1024x1024xf32> into memref<1x1024x1024x1xf32>
-  affine.for %arg3 = 0 to 1 {
-    affine.for %arg4 = 0 to 1024 {
-      affine.for %arg5 = 0 to 1020 {
-        affine.for %arg6 = 0 to 1 {
-          %1 = affine.load %0[%arg3, %arg4 + %arg3, %arg5, %arg6] : memref<1x1024x1024x1xf32>
-          affine.store %1, %arg1[%arg2] : memref<1xf32>
-        }
-      }
-    }
-  }
-  %2 = affine.load %arg1[%arg2] : memref<1xf32>
-  return %2 : f32
-}
-// CHECK-NEXT: affine.for %[[ARG3:.*]] = 0 to 1 {
-// CHECK-NEXT:  affine.for %[[ARG4:.*]] = 0 to 1024 {
-// CHECK-NEXT:   affine.for %[[ARG5:.*]] = 0 to 1020 {
-// CHECK-NEXT:    affine.for %[[ARG6:.*]] = 0 to 1 {
-// CHECK-NEXT:      %[[TMP1:.*]] = affine.apply #[[$MAP0]](%[[ARG3]], %[[ARG4]])
-// CHECK-NEXT:      %[[TMP2:.*]] = affine.linearize_index disjoint [%[[ARG3]], %[[TMP1]]] by (1, 1024)
-// CHECK-NEXT:      %[[TMP3:.*]] = affine.linearize_index disjoint [%[[ARG5]], %[[ARG6]]] by (1024, 1)
-// CHECK-NEXT:      affine.load %[[ARG0]][%[[TMP2]], %[[TMP3]]] : memref<1024x1024xf32>
-
-// -----
-
-// CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape_with_constant_access_index
+// CHECK-LABEL: fold_static_stride_subview_with_memref_expand_shape_with_constant_access_index
 // CHECK-SAME: (%[[ARG0:.*]]: memref<1024x1024xf32>, %[[ARG1:.*]]: memref<1xf32>, %[[ARG2:.*]]: index)
-func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_with_constant_access_index(%arg0: memref<1024x1024xf32>, %arg1: memref<1xf32>, %arg2: index) -> f32 {
+func.func @fold_static_stride_subview_with_memref_expand_shape_with_constant_access_index(%arg0: memref<1024x1024xf32>, %arg1: memref<1xf32>, %arg2: index) -> f32 {
   %0 = memref.expand_shape %arg0 [[0, 1], [2, 3]] output_shape [1, 1024, 1024, 1] : memref<1024x1024xf32> into memref<1x1024x1024x1xf32>
   %cst = arith.constant 0 : index
   affine.for %arg3 = 0 to 1 {
@@ -625,22 +450,6 @@ func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_with_c
 
 // -----
 
-// CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_collapse_shape_with_0d_result
-// CHECK-SAME: (%[[ARG0:.*]]: memref<1xf32>, %[[ARG1:.*]]: memref<1xf32>)
-func.func @fold_static_stride_subview_with_affine_load_store_collapse_shape_with_0d_result(%arg0: memref<1xf32>, %arg1: memref<1xf32>) -> memref<1xf32> {
-  %0 = memref.collapse_shape %arg0 [] : memref<1xf32> into memref<f32>
-  affine.for %arg2 = 0 to 3 {
-    %1 = affine.load %0[] : memref<f32>
-    affine.store %1, %arg1[0] : memref<1xf32>
-  }
-  return %arg1 : memref<1xf32>
-}
-// CHECK-NEXT: %[[ZERO:.*]] = arith.constant 0 : index
-// CHECK-NEXT: affine.for %{{.*}} = 0 to 3 {
-// CHECK-NEXT:   affine.load %[[ARG0]][%[[ZERO]]] : memref<1xf32>
-
-// -----
-
 //       CHECK: #[[$map:.*]] = affine_map<()[s0] -> (s0 + 2)>
 // CHECK-LABEL: func @subview_of_subview(
 //  CHECK-SAME:     %[[m:.*]]: memref<8x1024xf32, 3>, %[[pos:.*]]: index


        


More information about the Mlir-commits mailing list