[Mlir-commits] [mlir] 1b002d2 - Fold memref.expand_shape and memref.collapse_shape ops

Uday Bondhugula llvmlistbot at llvm.org
Sat Aug 27 18:32:52 PDT 2022


Author: Arnab Dutta
Date: 2022-08-28T06:56:06+05:30
New Revision: 1b002d27683522e15d6ef3bccb74e751bbf56e74

URL: https://github.com/llvm/llvm-project/commit/1b002d27683522e15d6ef3bccb74e751bbf56e74
DIFF: https://github.com/llvm/llvm-project/commit/1b002d27683522e15d6ef3bccb74e751bbf56e74.diff

LOG: Fold memref.expand_shape and memref.collapse_shape ops

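Fold memref.expand_shape and memref.collapse_shape ops into their consumer
memref/affine load/store ops. Since the pass no longer handles only
memref.subview, `fold-memref-subview-ops` is renamed to
`fold-memref-alias-ops`, and populateFoldSubViewOpPatterns /
createFoldSubViewOpsPass are renamed to populateFoldMemRefAliasOpPatterns /
createFoldMemRefAliasOpsPass.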

Reviewed By: bondhugula, nicolasvasilache

Differential Revision: https://reviews.llvm.org/D128986

Added: 
    mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
    mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir

Modified: 
    mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
    mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
    mlir/include/mlir/Dialect/Utils/IndexingUtils.h
    mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
    mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
    mlir/lib/Dialect/Utils/IndexingUtils.cpp
    mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp

Removed: 
    mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp
    mlir/test/Dialect/MemRef/fold-subview-ops.mlir


################################################################################
diff  --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
index 4c7781d06a2ac..82c2b5bceb9a2 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
@@ -39,9 +39,9 @@ class AllocOp;
 /// Collects a set of patterns to rewrite ops within the memref dialect.
 void populateExpandOpsPatterns(RewritePatternSet &patterns);
 
-/// Appends patterns for folding memref.subview ops into consumer load/store ops
-/// into `patterns`.
-void populateFoldSubViewOpPatterns(RewritePatternSet &patterns);
+/// Appends to `patterns` the patterns that fold memref aliasing ops
+/// (subview, expand_shape, collapse_shape) into consumer load/store ops.
+void populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns);
 
 /// Appends patterns that resolve `memref.dim` operations with values that are
 /// defined by operations that implement the
@@ -91,9 +91,9 @@ LogicalResult multiBuffer(memref::AllocOp allocOp, unsigned multiplier);
 /// `memref_reinterpret_cast`.
 std::unique_ptr<Pass> createExpandOpsPass();
 
-/// Creates an operation pass to fold memref.subview ops into consumer
-/// load/store ops into `patterns`.
+/// Creates an operation pass to fold memref aliasing ops into consumer
+/// load/store ops.
-std::unique_ptr<Pass> createFoldSubViewOpsPass();
+std::unique_ptr<Pass> createFoldMemRefAliasOpsPass();
 
 /// Creates an interprocedural pass to normalize memrefs to have a trivial
 /// (identity) layout map.

diff  --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
index ee769c11590b9..5ac124a0fcaa8 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
@@ -16,13 +16,13 @@ def ExpandOps : Pass<"memref-expand"> {
   let constructor = "mlir::memref::createExpandOpsPass()";
 }
 
-def FoldSubViewOps : Pass<"fold-memref-subview-ops"> {
-  let summary = "Fold memref.subview ops into consumer load/store ops";
+def FoldMemRefAliasOps : Pass<"fold-memref-alias-ops"> {
+  let summary = "Fold memref alias ops into consumer load/store ops";
   let description = [{
-    The pass folds loading/storing from/to subview ops to loading/storing
+    The pass folds loading/storing from/to memref aliasing ops to loading/storing
     from/to the original memref.
   }];
-  let constructor = "mlir::memref::createFoldSubViewOpsPass()";
+  let constructor = "mlir::memref::createFoldMemRefAliasOpsPass()";
   let dependentDialects = [
       "AffineDialect", "memref::MemRefDialect", "vector::VectorDialect"
   ];

diff  --git a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
index 3f2dd00c696f8..c1857462d4c67 100644
--- a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
@@ -14,6 +14,7 @@
 #ifndef MLIR_DIALECT_UTILS_INDEXINGUTILS_H
 #define MLIR_DIALECT_UTILS_INDEXINGUTILS_H
 
+#include "mlir/IR/Builders.h"
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/SmallVector.h"
@@ -47,6 +48,15 @@ void applyPermutationToVector(SmallVector<T, N> &inVec,
 SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
                                        unsigned dropFront = 0,
                                        unsigned dropBack = 0);
+
+/// Computes and returns linearized affine expression w.r.t. `basis`.
+mlir::AffineExpr getLinearAffineExpr(ArrayRef<int64_t> basis, mlir::Builder &b);
+
+/// Given the strides in the dimension space, returns the affine expressions for
+/// vector-space offsets in each dimension for a de-linearized index.
+SmallVector<mlir::AffineExpr, 4>
+getDelinearizedAffineExpr(ArrayRef<int64_t> strides, mlir::Builder &b);
+
 } // namespace mlir
 
 #endif // MLIR_DIALECT_UTILS_INDEXINGUTILS_H

diff  --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index aa6624f07a2dc..b7783c601a92b 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -1486,6 +1486,13 @@ def Vector_TransferWriteOp :
                    "ValueRange":$indices,
                    CArg<"Optional<ArrayRef<bool>>", "::llvm::None">:$inBounds)>,
   ];
+
+  let extraClassDeclaration = [{
+    /// Alias of getVector(), provided for uniformity with the store ops of
+    /// other dialects, which expose the stored value as getValue().
+    Value getValue() { return getVector(); }
+  }];
+
   let hasFolder = 1;
   let hasCanonicalizer = 1;
   let hasCustomAssemblyFormat = 1;

diff  --git a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
index 1f8167cc318e5..f85b6e50e91e4 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
@@ -1,7 +1,7 @@
 add_mlir_dialect_library(MLIRMemRefTransforms
   ComposeSubView.cpp
   ExpandOps.cpp
-  FoldSubViewOps.cpp
+  FoldMemRefAliasOps.cpp
   MultiBuffer.cpp
   NormalizeMemRefs.cpp
   ResolveShapedTypeResultDims.cpp

diff  --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
new file mode 100644
index 0000000000000..3ce18ec442c04
--- /dev/null
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
@@ -0,0 +1,562 @@
+//===- FoldMemRefAliasOps.cpp - Fold memref alias ops ---------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This transformation pass folds loading/storing from/to memref aliasing ops
+// (subview, expand_shape, collapse_shape) into loading/storing from/to the
+// original memref.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Transforms/Passes.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/SmallBitVector.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Utility functions
+//===----------------------------------------------------------------------===//
+
+/// Given the 'indices' of a load/store operation where the memref is a result
+/// of a expand_shape op, returns the indices w.r.t to the source memref of the
+/// expand_shape op. For example
+///
+/// %0 = ... : memref<12x42xf32>
+/// %1 = memref.expand_shape %0 [[0, 1], [2]]
+///    : memref<12x42xf32> into memref<2x6x42xf32>
+/// %2 = load %1[%i1, %i2, %i3] : memref<2x6x42xf32>
+///
+/// could be folded into
+///
+/// %2 = load %0[6 * %i1 + %i2, %i3] :
+///          memref<12x42xf32>
+///
+/// Fails without creating any IR when a dimension of the expand_shape result
+/// that enters the linearization (any dimension of a reassociation group other
+/// than the group's leading one) is dynamic.
+static LogicalResult
+resolveSourceIndicesExpandShape(Location loc, PatternRewriter &rewriter,
+                                memref::ExpandShapeOp expandShapeOp,
+                                ValueRange indices,
+                                SmallVectorImpl<Value> &sourceIndices) {
+  MemRefType resultType = expandShapeOp.getType().cast<MemRefType>();
+  auto reassociation = expandShapeOp.getReassociationIndices();
+
+  // Validate before creating any IR so that a failure leaves the IR
+  // untouched. The leading size of each group never enters the suffix
+  // product below; every other size must be static, otherwise getDimSize
+  // yields the dynamic-size marker instead of a size and the computed
+  // strides are wrong.
+  for (ArrayRef<int64_t> group : reassociation) {
+    assert(!group.empty() && "association indices groups cannot be empty");
+    for (int64_t dim : group.drop_front())
+      if (resultType.isDynamicDim(dim))
+        return failure();
+  }
+
+  for (ArrayRef<int64_t> group : reassociation) {
+    unsigned groupSize = group.size();
+    // suffixProduct[i] is the product of the sizes of result dimensions
+    // group[i + 1], ..., group[groupSize - 1].
+    SmallVector<int64_t> suffixProduct(groupSize);
+    suffixProduct[groupSize - 1] = 1;
+    for (unsigned i = groupSize - 1; i > 0; i--)
+      suffixProduct[i - 1] =
+          suffixProduct[i] * resultType.getDimSize(group[i]);
+    SmallVector<Value> dynamicIndices;
+    dynamicIndices.reserve(groupSize);
+    for (int64_t dim : group)
+      dynamicIndices.push_back(indices[dim]);
+    // Construct the expression for the index value w.r.t to expand shape op
+    // source corresponding the indices wrt to expand shape op result.
+    AffineExpr srcIndexExpr = getLinearAffineExpr(suffixProduct, rewriter);
+    sourceIndices.push_back(rewriter.create<AffineApplyOp>(
+        loc,
+        AffineMap::get(/*numDims=*/groupSize, /*numSymbols=*/0, srcIndexExpr),
+        dynamicIndices));
+  }
+  return success();
+}
+
+/// Given the 'indices' of a load/store operation where the memref is a result
+/// of a collapse_shape op, returns the indices w.r.t to the source memref of
+/// the collapse_shape op. For example
+///
+/// %0 = ... : memref<2x6x42xf32>
+/// %1 = memref.collapse_shape %0 [[0, 1], [2]]
+///    : memref<2x6x42xf32> into memref<12x42xf32>
+/// %2 = load %1[%i1, %i2] : memref<12x42xf32>
+///
+/// could be folded into
+///
+/// %2 = load %0[%i1 / 6, %i1 % 6, %i2] :
+///          memref<2x6x42xf32>
+///
+/// Fails without creating any IR when a source dimension that enters the
+/// delinearization (any dimension of a reassociation group other than the
+/// group's leading one) is dynamic.
+static LogicalResult
+resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter,
+                                  memref::CollapseShapeOp collapseShapeOp,
+                                  ValueRange indices,
+                                  SmallVectorImpl<Value> &sourceIndices) {
+  MemRefType srcType = collapseShapeOp.getSrcType();
+  auto reassociation = collapseShapeOp.getReassociationIndices();
+
+  // Validate before creating any IR so that a failure leaves the IR
+  // untouched. The leading size of each group never enters the suffix
+  // product below; every other size must be static, otherwise getDimSize
+  // yields the dynamic-size marker instead of a size and the computed
+  // strides are wrong.
+  for (ArrayRef<int64_t> group : reassociation) {
+    assert(!group.empty() && "association indices groups cannot be empty");
+    for (int64_t dim : group.drop_front())
+      if (srcType.isDynamicDim(dim))
+        return failure();
+  }
+
+  // Result dimension `resultDim` collapses the source dimensions listed in
+  // reassociation[resultDim]; delinearize its index into one index per
+  // source dimension of that group.
+  for (unsigned resultDim = 0, e = reassociation.size(); resultDim < e;
+       ++resultDim) {
+    ArrayRef<int64_t> group = reassociation[resultDim];
+    unsigned groupSize = group.size();
+    // suffixProduct[i] is the product of the sizes of source dimensions
+    // group[i + 1], ..., group[groupSize - 1].
+    SmallVector<int64_t> suffixProduct(groupSize);
+    suffixProduct[groupSize - 1] = 1;
+    for (unsigned i = groupSize - 1; i > 0; i--)
+      suffixProduct[i - 1] = suffixProduct[i] * srcType.getDimSize(group[i]);
+    SmallVector<AffineExpr, 4> srcIndexExpr =
+        getDelinearizedAffineExpr(suffixProduct, rewriter);
+    Value collapsedIndex = indices[resultDim];
+    for (unsigned i = 0; i < groupSize; i++)
+      sourceIndices.push_back(rewriter.create<AffineApplyOp>(
+          loc, AffineMap::get(/*numDims=*/1, /*numSymbols=*/0, srcIndexExpr[i]),
+          ValueRange(collapsedIndex)));
+  }
+
+  // Collapsing to a 0-d memref leaves no reassociation groups: access the
+  // source at index 0 along every dimension.
+  if (reassociation.empty()) {
+    AffineMap zeroAffineMap = rewriter.getConstantAffineMap(0);
+    for (int64_t i = 0, srcRank = srcType.getRank(); i < srcRank; i++)
+      sourceIndices.push_back(
+          rewriter.create<AffineApplyOp>(loc, zeroAffineMap, ValueRange()));
+  }
+  return success();
+}
+
+/// Given the 'indices' of a load/store operation where the memref is a result
+/// of a subview op, returns the indices w.r.t to the source memref of the
+/// subview op. For example
+///
+/// %0 = ... : memref<12x42xf32>
+/// %1 = memref.subview %0[%arg0, %arg1][4, 4][%stride1, %stride2]
+///    : memref<12x42xf32> to memref<4x4xf32, offset=?, strides=[?, ?]>
+/// %2 = load %1[%i1, %i2] : memref<4x4xf32, offset=?, strides=[?, ?]>
+///
+/// could be folded into
+///
+/// %2 = load %0[%arg0 + %i1 * %stride1, %arg1 + %i2 * %stride2] :
+///          memref<12x42xf32>
+///
+/// Fails without creating any IR when the offsets, strides or `indices` do not
+/// line up with the subview's source rank.
+static LogicalResult
+resolveSourceIndicesSubView(Location loc, PatternRewriter &rewriter,
+                            memref::SubViewOp subViewOp, ValueRange indices,
+                            SmallVectorImpl<Value> &sourceIndices) {
+  SmallVector<OpFoldResult> mixedOffsets = subViewOp.getMixedOffsets();
+  SmallVector<OpFoldResult> mixedStrides = subViewOp.getMixedStrides();
+  unsigned sourceRank = subViewOp.getSourceType().getRank();
+  llvm::SmallBitVector unusedDims = subViewOp.getDroppedDims();
+
+  // Validate before creating any IR so that a failure leaves the IR
+  // untouched: expect one offset and one stride per source dimension, and one
+  // index per source dimension that the subview does not drop.
+  if (mixedOffsets.size() != sourceRank || mixedStrides.size() != sourceRank ||
+      indices.size() + unusedDims.count() != sourceRank)
+    return failure();
+
+  // Rank-reducing case: for every dropped unit dimension, access the source
+  // at index 0.
+  SmallVector<Value> useIndices;
+  useIndices.reserve(sourceRank);
+  unsigned resultDim = 0;
+  for (unsigned dim : llvm::seq<unsigned>(0, sourceRank)) {
+    if (unusedDims.test(dim))
+      useIndices.push_back(rewriter.create<arith::ConstantIndexOp>(loc, 0));
+    else
+      useIndices.push_back(indices[resultDim++]);
+  }
+
+  // sourceIndex = useIndex * stride + offset. Static strides/offsets are
+  // folded into the affine expression, dynamic ones become map symbols.
+  sourceIndices.resize(sourceRank);
+  for (unsigned index : llvm::seq<unsigned>(0, sourceRank)) {
+    SmallVector<Value> dynamicOperands{useIndices[index]};
+    AffineExpr expr = rewriter.getAffineDimExpr(0);
+    unsigned numSymbols = 0;
+
+    // Multiply by the stride.
+    if (auto attr = mixedStrides[index].dyn_cast<Attribute>()) {
+      expr = expr * attr.cast<IntegerAttr>().getInt();
+    } else {
+      dynamicOperands.push_back(mixedStrides[index].get<Value>());
+      expr = expr * rewriter.getAffineSymbolExpr(numSymbols++);
+    }
+
+    // Add the offset.
+    if (auto attr = mixedOffsets[index].dyn_cast<Attribute>()) {
+      expr = expr + attr.cast<IntegerAttr>().getInt();
+    } else {
+      dynamicOperands.push_back(mixedOffsets[index].get<Value>());
+      expr = expr + rewriter.getAffineSymbolExpr(numSymbols++);
+    }
+
+    sourceIndices[index] = rewriter.create<AffineApplyOp>(
+        loc, AffineMap::get(1, numSymbols, expr), dynamicOperands);
+  }
+  return success();
+}
+
+/// Returns the memref accessed by a memref/affine load or store; all of these
+/// expose it through getMemref().
+template <typename LoadOrStoreOpTy>
+static Value getMemRefOperand(LoadOrStoreOpTy op) {
+  Value memref = op.getMemref();
+  return memref;
+}
+
+/// vector.transfer_read names its accessed shaped operand `source`.
+static Value getMemRefOperand(vector::TransferReadOp op) {
+  Value source = op.getSource();
+  return source;
+}
+
+/// vector.transfer_write names its accessed shaped operand `source`.
+static Value getMemRefOperand(vector::TransferWriteOp op) {
+  Value source = op.getSource();
+  return source;
+}
+
+/// Computes the permutation map of a `vector.transfer_read` /
+/// `vector.transfer_write` after its subview operand is folded away: the
+/// current map is composed with a map from the subview source's dimensions to
+/// the subview result's dimensions (the source dimensions that are not
+/// dropped by a rank-reducing subview).
+static AffineMapAttr getPermutationMapAttr(MLIRContext *context,
+                                           memref::SubViewOp subViewOp,
+                                           AffineMap currPermutationMap) {
+  int64_t sourceRank = subViewOp.getSourceType().getRank();
+  llvm::SmallBitVector droppedDims = subViewOp.getDroppedDims();
+
+  // One dim expression per source dimension that survives in the result.
+  SmallVector<AffineExpr> keptDimExprs;
+  for (int64_t dim = 0; dim < sourceRank; ++dim) {
+    if (!droppedDims.test(dim))
+      keptDimExprs.push_back(getAffineDimExpr(dim, context));
+  }
+
+  AffineMap sourceToResultDims =
+      AffineMap::get(sourceRank, 0, keptDimExprs, context);
+  AffineMap composed = currPermutationMap.compose(sourceToResultDims);
+  return AffineMapAttr::get(composed);
+}
+
+//===----------------------------------------------------------------------===//
+// Patterns
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Folds a memref.subview into a consuming affine.load, memref.load or
+/// vector.transfer_read.
+template <typename OpTy>
+struct LoadOpOfSubViewOpFolder final : OpRewritePattern<OpTy> {
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(OpTy loadOp,
+                                PatternRewriter &rewriter) const override;
+};
+
+/// Folds a memref.expand_shape into a consuming affine.load or memref.load.
+template <typename OpTy>
+struct LoadOpOfExpandShapeOpFolder final : OpRewritePattern<OpTy> {
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(OpTy loadOp,
+                                PatternRewriter &rewriter) const override;
+};
+
+/// Folds a memref.collapse_shape into a consuming affine.load or memref.load.
+template <typename OpTy>
+struct LoadOpOfCollapseShapeOpFolder final : OpRewritePattern<OpTy> {
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(OpTy loadOp,
+                                PatternRewriter &rewriter) const override;
+};
+
+/// Folds a memref.subview into a consuming affine.store, memref.store or
+/// vector.transfer_write.
+template <typename OpTy>
+struct StoreOpOfSubViewOpFolder final : OpRewritePattern<OpTy> {
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(OpTy storeOp,
+                                PatternRewriter &rewriter) const override;
+};
+
+/// Folds a memref.expand_shape into a consuming affine.store or memref.store.
+template <typename OpTy>
+struct StoreOpOfExpandShapeOpFolder final : OpRewritePattern<OpTy> {
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(OpTy storeOp,
+                                PatternRewriter &rewriter) const override;
+};
+
+/// Folds a memref.collapse_shape into a consuming affine.store or
+/// memref.store.
+template <typename OpTy>
+struct StoreOpOfCollapseShapeOpFolder final : OpRewritePattern<OpTy> {
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(OpTy storeOp,
+                                PatternRewriter &rewriter) const override;
+};
+
+} // namespace
+
+/// Materializes the results of an affine access map: returns one
+/// `affine.apply` value per result of `affineMap`, each applied to `indices`
+/// (the map operands of an affine.load / affine.store).
+static SmallVector<Value>
+calculateExpandedAccessIndices(AffineMap affineMap, ValueRange indices,
+                               Location loc, PatternRewriter &rewriter) {
+  SmallVector<Value> expandedIndices;
+  expandedIndices.reserve(affineMap.getNumResults());
+  for (unsigned i = 0, e = affineMap.getNumResults(); i < e; i++)
+    expandedIndices.push_back(
+        rewriter.create<AffineApplyOp>(loc, affineMap.getSubMap({i}), indices));
+  return expandedIndices;
+}
+
+/// Rewrites a load from a memref.subview into a load from the subview's
+/// source at the corresponding source indices. Handles affine.load,
+/// memref.load and vector.transfer_read.
+template <typename OpTy>
+LogicalResult LoadOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
+    OpTy loadOp, PatternRewriter &rewriter) const {
+  auto subViewOp =
+      getMemRefOperand(loadOp).template getDefiningOp<memref::SubViewOp>();
+
+  if (!subViewOp)
+    return failure();
+
+  // Reject unsupported transfer_reads before creating any IR, so that a
+  // failed match leaves the IR untouched (previously a 0-d read returned
+  // success() without being replaced). 0-d transfers are not handled yet, and
+  // the replacement below is built without a mask, so a masked read would
+  // silently lose its mask.
+  if (auto transferReadOp =
+          dyn_cast<vector::TransferReadOp>(loadOp.getOperation())) {
+    if (transferReadOp.getTransferRank() == 0 || transferReadOp.getMask())
+      return failure();
+  }
+
+  SmallVector<Value> indices(loadOp.getIndices().begin(),
+                             loadOp.getIndices().end());
+  // For affine ops, we need to apply the map to get the operands to get the
+  // "actual" indices.
+  if (auto affineLoadOp = dyn_cast<AffineLoadOp>(loadOp.getOperation())) {
+    AffineMap affineMap = affineLoadOp.getAffineMap();
+    auto expandedIndices = calculateExpandedAccessIndices(
+        affineMap, indices, loadOp.getLoc(), rewriter);
+    indices.assign(expandedIndices.begin(), expandedIndices.end());
+  }
+  SmallVector<Value, 4> sourceIndices;
+  if (failed(resolveSourceIndicesSubView(loadOp.getLoc(), rewriter, subViewOp,
+                                         indices, sourceIndices)))
+    return failure();
+  llvm::TypeSwitch<Operation *, void>(loadOp)
+      .Case<AffineLoadOp, memref::LoadOp>([&](auto op) {
+        rewriter.replaceOpWithNewOp<decltype(op)>(loadOp, subViewOp.source(),
+                                                  sourceIndices);
+      })
+      .Case([&](vector::TransferReadOp transferReadOp) {
+        // NOTE(review): dims not marked in_bounds may now read past the
+        // subview into the source memref; confirm this is intended.
+        rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
+            transferReadOp, transferReadOp.getVectorType(), subViewOp.source(),
+            sourceIndices,
+            getPermutationMapAttr(rewriter.getContext(), subViewOp,
+                                  transferReadOp.getPermutationMap()),
+            transferReadOp.getPadding(),
+            /*mask=*/Value(), transferReadOp.getInBoundsAttr());
+      })
+      .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
+  return success();
+}
+
+/// Rewrites a load from a memref.expand_shape into a load from the
+/// expand_shape's source, linearizing the indices of each reassociation
+/// group. Handles affine.load and memref.load.
+template <typename OpTy>
+LogicalResult LoadOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
+    OpTy loadOp, PatternRewriter &rewriter) const {
+  memref::ExpandShapeOp expandShapeOp =
+      getMemRefOperand(loadOp).template getDefiningOp<memref::ExpandShapeOp>();
+
+  if (!expandShapeOp)
+    return failure();
+
+  // Linearizing the indices needs the size of every non-leading dimension of
+  // each reassociation group. Check this before the affine index expansion
+  // below creates any IR, so that a failed match leaves the IR untouched.
+  MemRefType resultType = expandShapeOp.getType().cast<MemRefType>();
+  for (ArrayRef<int64_t> group : expandShapeOp.getReassociationIndices())
+    for (int64_t dim : group.drop_front())
+      if (resultType.isDynamicDim(dim))
+        return failure();
+
+  SmallVector<Value> indices(loadOp.getIndices().begin(),
+                             loadOp.getIndices().end());
+  // For affine ops, we need to apply the map to get the operands to get the
+  // "actual" indices.
+  if (auto affineLoadOp = dyn_cast<AffineLoadOp>(loadOp.getOperation())) {
+    AffineMap affineMap = affineLoadOp.getAffineMap();
+    auto expandedIndices = calculateExpandedAccessIndices(
+        affineMap, indices, loadOp.getLoc(), rewriter);
+    indices.assign(expandedIndices.begin(), expandedIndices.end());
+  }
+  SmallVector<Value, 4> sourceIndices;
+  if (failed(resolveSourceIndicesExpandShape(
+          loadOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices)))
+    return failure();
+  llvm::TypeSwitch<Operation *, void>(loadOp)
+      .Case<AffineLoadOp, memref::LoadOp>([&](auto op) {
+        rewriter.replaceOpWithNewOp<decltype(op)>(
+            loadOp, expandShapeOp.getViewSource(), sourceIndices);
+      })
+      .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
+  return success();
+}
+
+/// Rewrites a load from a memref.collapse_shape into a load from the
+/// collapse_shape's source, delinearizing each collapsed index. Handles
+/// affine.load and memref.load.
+template <typename OpTy>
+LogicalResult LoadOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
+    OpTy loadOp, PatternRewriter &rewriter) const {
+  memref::CollapseShapeOp collapseShapeOp =
+      getMemRefOperand(loadOp)
+          .template getDefiningOp<memref::CollapseShapeOp>();
+
+  if (!collapseShapeOp)
+    return failure();
+
+  // Delinearizing the indices needs the size of every non-leading source
+  // dimension of each reassociation group. Check this before the affine index
+  // expansion below creates any IR, so that a failed match leaves the IR
+  // untouched.
+  MemRefType srcType = collapseShapeOp.getSrcType();
+  for (ArrayRef<int64_t> group : collapseShapeOp.getReassociationIndices())
+    for (int64_t dim : group.drop_front())
+      if (srcType.isDynamicDim(dim))
+        return failure();
+
+  SmallVector<Value> indices(loadOp.getIndices().begin(),
+                             loadOp.getIndices().end());
+  // For affine ops, we need to apply the map to get the operands to get the
+  // "actual" indices.
+  if (auto affineLoadOp = dyn_cast<AffineLoadOp>(loadOp.getOperation())) {
+    AffineMap affineMap = affineLoadOp.getAffineMap();
+    auto expandedIndices = calculateExpandedAccessIndices(
+        affineMap, indices, loadOp.getLoc(), rewriter);
+    indices.assign(expandedIndices.begin(), expandedIndices.end());
+  }
+  SmallVector<Value, 4> sourceIndices;
+  if (failed(resolveSourceIndicesCollapseShape(
+          loadOp.getLoc(), rewriter, collapseShapeOp, indices, sourceIndices)))
+    return failure();
+  llvm::TypeSwitch<Operation *, void>(loadOp)
+      .Case<AffineLoadOp, memref::LoadOp>([&](auto op) {
+        rewriter.replaceOpWithNewOp<decltype(op)>(
+            loadOp, collapseShapeOp.getViewSource(), sourceIndices);
+      })
+      .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
+  return success();
+}
+
+/// Rewrites a store into a memref.subview into a store into the subview's
+/// source at the corresponding source indices. Handles affine.store,
+/// memref.store and vector.transfer_write.
+template <typename OpTy>
+LogicalResult StoreOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
+    OpTy storeOp, PatternRewriter &rewriter) const {
+  auto subViewOp =
+      getMemRefOperand(storeOp).template getDefiningOp<memref::SubViewOp>();
+
+  if (!subViewOp)
+    return failure();
+
+  // Reject unsupported transfer_writes before creating any IR, so that a
+  // failed match leaves the IR untouched (previously a 0-d write returned
+  // success() without being replaced). 0-d transfers are not handled yet, and
+  // the replacement below is built without a mask, so a masked write would
+  // silently lose its mask.
+  if (auto transferWriteOp =
+          dyn_cast<vector::TransferWriteOp>(storeOp.getOperation())) {
+    if (transferWriteOp.getTransferRank() == 0 || transferWriteOp.getMask())
+      return failure();
+  }
+
+  SmallVector<Value> indices(storeOp.getIndices().begin(),
+                             storeOp.getIndices().end());
+  // For affine ops, we need to apply the map to get the operands to get the
+  // "actual" indices.
+  if (auto affineStoreOp = dyn_cast<AffineStoreOp>(storeOp.getOperation())) {
+    AffineMap affineMap = affineStoreOp.getAffineMap();
+    auto expandedIndices = calculateExpandedAccessIndices(
+        affineMap, indices, storeOp.getLoc(), rewriter);
+    indices.assign(expandedIndices.begin(), expandedIndices.end());
+  }
+  SmallVector<Value, 4> sourceIndices;
+  if (failed(resolveSourceIndicesSubView(storeOp.getLoc(), rewriter, subViewOp,
+                                         indices, sourceIndices)))
+    return failure();
+  llvm::TypeSwitch<Operation *, void>(storeOp)
+      .Case<AffineStoreOp, memref::StoreOp>([&](auto op) {
+        rewriter.replaceOpWithNewOp<decltype(op)>(
+            storeOp, storeOp.getValue(), subViewOp.source(), sourceIndices);
+      })
+      .Case([&](vector::TransferWriteOp op) {
+        // NOTE(review): dims not marked in_bounds may now write past the
+        // subview into the source memref; confirm this is intended.
+        rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
+            op, op.getValue(), subViewOp.source(), sourceIndices,
+            getPermutationMapAttr(rewriter.getContext(), subViewOp,
+                                  op.getPermutationMap()),
+            op.getInBoundsAttr());
+      })
+      .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
+  return success();
+}
+
+/// Rewrites a store into a memref.expand_shape into a store into the
+/// expand_shape's source, linearizing the indices of each reassociation
+/// group. Handles affine.store and memref.store.
+template <typename OpTy>
+LogicalResult StoreOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
+    OpTy storeOp, PatternRewriter &rewriter) const {
+  memref::ExpandShapeOp expandShapeOp =
+      getMemRefOperand(storeOp).template getDefiningOp<memref::ExpandShapeOp>();
+
+  if (!expandShapeOp)
+    return failure();
+
+  // Linearizing the indices needs the size of every non-leading dimension of
+  // each reassociation group. Check this before the affine index expansion
+  // below creates any IR, so that a failed match leaves the IR untouched.
+  MemRefType resultType = expandShapeOp.getType().cast<MemRefType>();
+  for (ArrayRef<int64_t> group : expandShapeOp.getReassociationIndices())
+    for (int64_t dim : group.drop_front())
+      if (resultType.isDynamicDim(dim))
+        return failure();
+
+  SmallVector<Value> indices(storeOp.getIndices().begin(),
+                             storeOp.getIndices().end());
+  // For affine ops, we need to apply the map to get the operands to get the
+  // "actual" indices.
+  if (auto affineStoreOp = dyn_cast<AffineStoreOp>(storeOp.getOperation())) {
+    AffineMap affineMap = affineStoreOp.getAffineMap();
+    auto expandedIndices = calculateExpandedAccessIndices(
+        affineMap, indices, storeOp.getLoc(), rewriter);
+    indices.assign(expandedIndices.begin(), expandedIndices.end());
+  }
+  SmallVector<Value, 4> sourceIndices;
+  if (failed(resolveSourceIndicesExpandShape(
+          storeOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices)))
+    return failure();
+  llvm::TypeSwitch<Operation *, void>(storeOp)
+      .Case<AffineStoreOp, memref::StoreOp>([&](auto op) {
+        rewriter.replaceOpWithNewOp<decltype(op)>(storeOp, storeOp.getValue(),
+                                                  expandShapeOp.getViewSource(),
+                                                  sourceIndices);
+      })
+      .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
+  return success();
+}
+
+/// Rewrites a store into a memref.collapse_shape into a store into the
+/// collapse_shape's source, delinearizing each collapsed index. Handles
+/// affine.store and memref.store.
+template <typename OpTy>
+LogicalResult StoreOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
+    OpTy storeOp, PatternRewriter &rewriter) const {
+  memref::CollapseShapeOp collapseShapeOp =
+      getMemRefOperand(storeOp)
+          .template getDefiningOp<memref::CollapseShapeOp>();
+
+  if (!collapseShapeOp)
+    return failure();
+
+  // Delinearizing the indices needs the size of every non-leading source
+  // dimension of each reassociation group. Check this before the affine index
+  // expansion below creates any IR, so that a failed match leaves the IR
+  // untouched.
+  MemRefType srcType = collapseShapeOp.getSrcType();
+  for (ArrayRef<int64_t> group : collapseShapeOp.getReassociationIndices())
+    for (int64_t dim : group.drop_front())
+      if (srcType.isDynamicDim(dim))
+        return failure();
+
+  SmallVector<Value> indices(storeOp.getIndices().begin(),
+                             storeOp.getIndices().end());
+  // For affine ops, we need to apply the map to get the operands to get the
+  // "actual" indices.
+  if (auto affineStoreOp = dyn_cast<AffineStoreOp>(storeOp.getOperation())) {
+    AffineMap affineMap = affineStoreOp.getAffineMap();
+    auto expandedIndices = calculateExpandedAccessIndices(
+        affineMap, indices, storeOp.getLoc(), rewriter);
+    indices.assign(expandedIndices.begin(), expandedIndices.end());
+  }
+  SmallVector<Value, 4> sourceIndices;
+  if (failed(resolveSourceIndicesCollapseShape(
+          storeOp.getLoc(), rewriter, collapseShapeOp, indices, sourceIndices)))
+    return failure();
+  llvm::TypeSwitch<Operation *, void>(storeOp)
+      .Case<AffineStoreOp, memref::StoreOp>([&](auto op) {
+        rewriter.replaceOpWithNewOp<decltype(op)>(
+            storeOp, storeOp.getValue(), collapseShapeOp.getViewSource(),
+            sourceIndices);
+      })
+      .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
+  return success();
+}
+
+void memref::populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns) {
+  MLIRContext *context = patterns.getContext();
+  // memref.subview folds into affine/memref loads and stores as well as into
+  // vector transfers.
+  patterns.add<LoadOpOfSubViewOpFolder<AffineLoadOp>,
+               LoadOpOfSubViewOpFolder<memref::LoadOp>,
+               LoadOpOfSubViewOpFolder<vector::TransferReadOp>,
+               StoreOpOfSubViewOpFolder<AffineStoreOp>,
+               StoreOpOfSubViewOpFolder<memref::StoreOp>,
+               StoreOpOfSubViewOpFolder<vector::TransferWriteOp>>(context);
+  // memref.expand_shape folds into affine/memref loads and stores only.
+  patterns.add<LoadOpOfExpandShapeOpFolder<AffineLoadOp>,
+               LoadOpOfExpandShapeOpFolder<memref::LoadOp>,
+               StoreOpOfExpandShapeOpFolder<AffineStoreOp>,
+               StoreOpOfExpandShapeOpFolder<memref::StoreOp>>(context);
+  // memref.collapse_shape folds into affine/memref loads and stores only.
+  patterns.add<LoadOpOfCollapseShapeOpFolder<AffineLoadOp>,
+               LoadOpOfCollapseShapeOpFolder<memref::LoadOp>,
+               StoreOpOfCollapseShapeOpFolder<AffineStoreOp>,
+               StoreOpOfCollapseShapeOpFolder<memref::StoreOp>>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// Pass registration
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+// Pull in the tablegen-generated pass base classes (including
+// FoldMemRefAliasOpsBase, generated from the FoldMemRefAliasOps definition in
+// Passes.td).
+#define GEN_PASS_CLASSES
+#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
+
+/// Pass that applies the patterns from
+/// memref::populateFoldMemRefAliasOpPatterns to the regions of the operation
+/// it runs on.
+struct FoldMemRefAliasOpsPass final
+    : public FoldMemRefAliasOpsBase<FoldMemRefAliasOpsPass> {
+  void runOnOperation() override;
+};
+
+} // namespace
+
+/// Greedily applies the memref alias folding patterns to every region of the
+/// operation this pass runs on.
+void FoldMemRefAliasOpsPass::runOnOperation() {
+  MLIRContext *context = &getContext();
+  RewritePatternSet patterns(context);
+  memref::populateFoldMemRefAliasOpPatterns(patterns);
+  // Best effort: a failure of the greedy driver to converge is deliberately
+  // ignored.
+  Operation *op = getOperation();
+  (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns));
+}
+
+/// Factory for the pass registered as `fold-memref-alias-ops`.
+std::unique_ptr<Pass> memref::createFoldMemRefAliasOpsPass() {
+  std::unique_ptr<Pass> pass = std::make_unique<FoldMemRefAliasOpsPass>();
+  return pass;
+}

diff  --git a/mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp
deleted file mode 100644
index 85d28ee5a020f..0000000000000
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp
+++ /dev/null
@@ -1,276 +0,0 @@
-//===- FoldSubViewOps.cpp - Fold memref.subview ops -----------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This transformation pass folds loading/storing from/to subview ops into
-// loading/storing from/to the original memref.
-//
-//===----------------------------------------------------------------------===//
-
-#include "PassDetail.h"
-#include "mlir/Dialect/Affine/IR/AffineOps.h"
-#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
-#include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/Dialect/MemRef/Transforms/Passes.h"
-#include "mlir/Dialect/Vector/IR/VectorOps.h"
-#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "llvm/ADT/SmallBitVector.h"
-
-using namespace mlir;
-
-//===----------------------------------------------------------------------===//
-// Utility functions
-//===----------------------------------------------------------------------===//
-
-/// Given the 'indices' of an load/store operation where the memref is a result
-/// of a subview op, returns the indices w.r.t to the source memref of the
-/// subview op. For example
-///
-/// %0 = ... : memref<12x42xf32>
-/// %1 = subview %0[%arg0, %arg1][][%stride1, %stride2] : memref<12x42xf32> to
-///          memref<4x4xf32, offset=?, strides=[?, ?]>
-/// %2 = load %1[%i1, %i2] : memref<4x4xf32, offset=?, strides=[?, ?]>
-///
-/// could be folded into
-///
-/// %2 = load %0[%arg0 + %i1 * %stride1][%arg1 + %i2 * %stride2] :
-///          memref<12x42xf32>
-static LogicalResult
-resolveSourceIndices(Location loc, PatternRewriter &rewriter,
-                     memref::SubViewOp subViewOp, ValueRange indices,
-                     SmallVectorImpl<Value> &sourceIndices) {
-  SmallVector<OpFoldResult> mixedOffsets = subViewOp.getMixedOffsets();
-  SmallVector<OpFoldResult> mixedSizes = subViewOp.getMixedSizes();
-  SmallVector<OpFoldResult> mixedStrides = subViewOp.getMixedStrides();
-
-  SmallVector<Value> useIndices;
-  // Check if this is rank-reducing case. Then for every unit-dim size add a
-  // zero to the indices.
-  unsigned resultDim = 0;
-  llvm::SmallBitVector unusedDims = subViewOp.getDroppedDims();
-  for (auto dim : llvm::seq<unsigned>(0, subViewOp.getSourceType().getRank())) {
-    if (unusedDims.test(dim))
-      useIndices.push_back(rewriter.create<arith::ConstantIndexOp>(loc, 0));
-    else
-      useIndices.push_back(indices[resultDim++]);
-  }
-  if (useIndices.size() != mixedOffsets.size())
-    return failure();
-  sourceIndices.resize(useIndices.size());
-  for (auto index : llvm::seq<size_t>(0, mixedOffsets.size())) {
-    SmallVector<Value> dynamicOperands;
-    AffineExpr expr = rewriter.getAffineDimExpr(0);
-    unsigned numSymbols = 0;
-    dynamicOperands.push_back(useIndices[index]);
-
-    // Multiply the stride;
-    if (auto attr = mixedStrides[index].dyn_cast<Attribute>()) {
-      expr = expr * attr.cast<IntegerAttr>().getInt();
-    } else {
-      dynamicOperands.push_back(mixedStrides[index].get<Value>());
-      expr = expr * rewriter.getAffineSymbolExpr(numSymbols++);
-    }
-
-    // Add the offset.
-    if (auto attr = mixedOffsets[index].dyn_cast<Attribute>()) {
-      expr = expr + attr.cast<IntegerAttr>().getInt();
-    } else {
-      dynamicOperands.push_back(mixedOffsets[index].get<Value>());
-      expr = expr + rewriter.getAffineSymbolExpr(numSymbols++);
-    }
-    Location loc = subViewOp.getLoc();
-    sourceIndices[index] = rewriter.create<AffineApplyOp>(
-        loc, AffineMap::get(1, numSymbols, expr), dynamicOperands);
-  }
-  return success();
-}
-
-/// Helpers to access the memref operand for each op.
-template <typename LoadOrStoreOpTy>
-static Value getMemRefOperand(LoadOrStoreOpTy op) {
-  return op.getMemref();
-}
-
-static Value getMemRefOperand(vector::TransferReadOp op) {
-  return op.getSource();
-}
-
-static Value getMemRefOperand(vector::TransferWriteOp op) {
-  return op.getSource();
-}
-
-/// Given the permutation map of the original
-/// `vector.transfer_read`/`vector.transfer_write` operations compute the
-/// permutation map to use after the subview is folded with it.
-static AffineMapAttr getPermutationMapAttr(MLIRContext *context,
-                                           memref::SubViewOp subViewOp,
-                                           AffineMap currPermutationMap) {
-  llvm::SmallBitVector unusedDims = subViewOp.getDroppedDims();
-  SmallVector<AffineExpr> exprs;
-  int64_t sourceRank = subViewOp.getSourceType().getRank();
-  for (auto dim : llvm::seq<int64_t>(0, sourceRank)) {
-    if (unusedDims.test(dim))
-      continue;
-    exprs.push_back(getAffineDimExpr(dim, context));
-  }
-  auto resultDimToSourceDimMap = AffineMap::get(sourceRank, 0, exprs, context);
-  return AffineMapAttr::get(
-      currPermutationMap.compose(resultDimToSourceDimMap));
-}
-
-//===----------------------------------------------------------------------===//
-// Patterns
-//===----------------------------------------------------------------------===//
-
-namespace {
-/// Merges subview operation with load/transferRead operation.
-template <typename OpTy>
-class LoadOpOfSubViewFolder final : public OpRewritePattern<OpTy> {
-public:
-  using OpRewritePattern<OpTy>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(OpTy loadOp,
-                                PatternRewriter &rewriter) const override;
-
-private:
-  void replaceOp(OpTy loadOp, memref::SubViewOp subViewOp,
-                 ArrayRef<Value> sourceIndices,
-                 PatternRewriter &rewriter) const;
-};
-
-/// Merges subview operation with store/transferWriteOp operation.
-template <typename OpTy>
-class StoreOpOfSubViewFolder final : public OpRewritePattern<OpTy> {
-public:
-  using OpRewritePattern<OpTy>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(OpTy storeOp,
-                                PatternRewriter &rewriter) const override;
-
-private:
-  void replaceOp(OpTy storeOp, memref::SubViewOp subViewOp,
-                 ArrayRef<Value> sourceIndices,
-                 PatternRewriter &rewriter) const;
-};
-
-template <typename LoadOpTy>
-void LoadOpOfSubViewFolder<LoadOpTy>::replaceOp(
-    LoadOpTy loadOp, memref::SubViewOp subViewOp, ArrayRef<Value> sourceIndices,
-    PatternRewriter &rewriter) const {
-  rewriter.replaceOpWithNewOp<LoadOpTy>(loadOp, subViewOp.getSource(),
-                                        sourceIndices);
-}
-
-template <>
-void LoadOpOfSubViewFolder<vector::TransferReadOp>::replaceOp(
-    vector::TransferReadOp transferReadOp, memref::SubViewOp subViewOp,
-    ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
-  // TODO: support 0-d corner case.
-  if (transferReadOp.getTransferRank() == 0)
-    return;
-  rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
-      transferReadOp, transferReadOp.getVectorType(), subViewOp.getSource(),
-      sourceIndices,
-      getPermutationMapAttr(rewriter.getContext(), subViewOp,
-                            transferReadOp.getPermutationMap()),
-      transferReadOp.getPadding(),
-      /*mask=*/Value(), transferReadOp.getInBoundsAttr());
-}
-
-template <typename StoreOpTy>
-void StoreOpOfSubViewFolder<StoreOpTy>::replaceOp(
-    StoreOpTy storeOp, memref::SubViewOp subViewOp,
-    ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
-  rewriter.replaceOpWithNewOp<StoreOpTy>(storeOp, storeOp.getValue(),
-                                         subViewOp.getSource(), sourceIndices);
-}
-
-template <>
-void StoreOpOfSubViewFolder<vector::TransferWriteOp>::replaceOp(
-    vector::TransferWriteOp transferWriteOp, memref::SubViewOp subViewOp,
-    ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
-  // TODO: support 0-d corner case.
-  if (transferWriteOp.getTransferRank() == 0)
-    return;
-  rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
-      transferWriteOp, transferWriteOp.getVector(), subViewOp.getSource(),
-      sourceIndices,
-      getPermutationMapAttr(rewriter.getContext(), subViewOp,
-                            transferWriteOp.getPermutationMap()),
-      transferWriteOp.getInBoundsAttr());
-}
-} // namespace
-
-template <typename OpTy>
-LogicalResult
-LoadOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy loadOp,
-                                             PatternRewriter &rewriter) const {
-  auto subViewOp =
-      getMemRefOperand(loadOp).template getDefiningOp<memref::SubViewOp>();
-  if (!subViewOp)
-    return failure();
-
-  SmallVector<Value, 4> sourceIndices;
-  if (failed(resolveSourceIndices(loadOp.getLoc(), rewriter, subViewOp,
-                                  loadOp.getIndices(), sourceIndices)))
-    return failure();
-
-  replaceOp(loadOp, subViewOp, sourceIndices, rewriter);
-  return success();
-}
-
-template <typename OpTy>
-LogicalResult
-StoreOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy storeOp,
-                                              PatternRewriter &rewriter) const {
-  auto subViewOp =
-      getMemRefOperand(storeOp).template getDefiningOp<memref::SubViewOp>();
-  if (!subViewOp)
-    return failure();
-
-  SmallVector<Value, 4> sourceIndices;
-  if (failed(resolveSourceIndices(storeOp.getLoc(), rewriter, subViewOp,
-                                  storeOp.getIndices(), sourceIndices)))
-    return failure();
-
-  replaceOp(storeOp, subViewOp, sourceIndices, rewriter);
-  return success();
-}
-
-void memref::populateFoldSubViewOpPatterns(RewritePatternSet &patterns) {
-  patterns.add<LoadOpOfSubViewFolder<AffineLoadOp>,
-               LoadOpOfSubViewFolder<memref::LoadOp>,
-               LoadOpOfSubViewFolder<vector::TransferReadOp>,
-               StoreOpOfSubViewFolder<AffineStoreOp>,
-               StoreOpOfSubViewFolder<memref::StoreOp>,
-               StoreOpOfSubViewFolder<vector::TransferWriteOp>>(
-      patterns.getContext());
-}
-
-//===----------------------------------------------------------------------===//
-// Pass registration
-//===----------------------------------------------------------------------===//
-
-namespace {
-
-struct FoldSubViewOpsPass final
-    : public FoldSubViewOpsBase<FoldSubViewOpsPass> {
-  void runOnOperation() override;
-};
-
-} // namespace
-
-void FoldSubViewOpsPass::runOnOperation() {
-  RewritePatternSet patterns(&getContext());
-  memref::populateFoldSubViewOpPatterns(patterns);
-  (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
-}
-
-std::unique_ptr<Pass> memref::createFoldSubViewOpsPass() {
-  return std::make_unique<FoldSubViewOpsPass>();
-}

diff  --git a/mlir/lib/Dialect/Utils/IndexingUtils.cpp b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
index 2d1bfc1e92a50..b9901b7af0ff1 100644
--- a/mlir/lib/Dialect/Utils/IndexingUtils.cpp
+++ b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
@@ -8,6 +8,8 @@
 
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributes.h"
 
 int64_t mlir::linearize(ArrayRef<int64_t> offsets, ArrayRef<int64_t> basis) {
@@ -42,3 +44,36 @@ llvm::SmallVector<int64_t, 4> mlir::getI64SubArray(ArrayAttr arrayAttr,
     res.push_back((*it).getValue().getSExtValue());
   return res;
 }
+
+/// Returns the affine expression `d0 * basis[0] + d1 * basis[1] + ...`, i.e.
+/// the linearized form of a multi-dimensional index where `basis[i]` is the
+/// multiplier applied to dimension `i`.
+/// An empty `basis` yields the constant expression 0 (the empty sum) instead
+/// of reading `basis[0]` out of bounds.
+mlir::AffineExpr mlir::getLinearAffineExpr(ArrayRef<int64_t> basis,
+                                           mlir::Builder &b) {
+  if (basis.empty())
+    return b.getAffineConstantExpr(0);
+  AffineExpr resultExpr = b.getAffineDimExpr(0) * basis[0];
+  for (unsigned i = 1, e = basis.size(); i < e; ++i)
+    resultExpr = resultExpr + b.getAffineDimExpr(i) * basis[i];
+  return resultExpr;
+}
+
+/// Splits the single dimension `d0` into one affine expression per entry of
+/// `strides`: result `i` is the running remainder floor-divided by
+/// `strides[i]`, after which the remainder is reduced modulo `strides[i]`.
+/// E.g. strides [6, 1] give `d0 floordiv 6` and `(d0 mod 6) floordiv 1`.
+/// An empty `strides` yields an empty vector instead of writing element 0 of
+/// a zero-sized vector.
+llvm::SmallVector<mlir::AffineExpr, 4>
+mlir::getDelinearizedAffineExpr(mlir::ArrayRef<int64_t> strides, Builder &b) {
+  SmallVector<AffineExpr, 4> vectorOffsets;
+  vectorOffsets.reserve(strides.size());
+  AffineExpr remainder = b.getAffineDimExpr(0);
+  for (int64_t stride : strides) {
+    vectorOffsets.push_back(remainder.floorDiv(stride));
+    remainder = remainder % stride;
+  }
+  return vectorOffsets;
+}

diff  --git a/mlir/test/Dialect/MemRef/fold-subview-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
similarity index 64%
rename from mlir/test/Dialect/MemRef/fold-subview-ops.mlir
rename to mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
index 28138e93aec9e..18c2b3f403e98 100644
--- a/mlir/test/Dialect/MemRef/fold-subview-ops.mlir
+++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -fold-memref-subview-ops -split-input-file %s -o - | FileCheck %s
+// RUN: mlir-opt -fold-memref-alias-ops -split-input-file %s -o - | FileCheck %s
 
 func.func @fold_static_stride_subview_with_load(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index) -> f32 {
   %0 = memref.subview %arg0[%arg1, %arg2][4, 4][2, 3] : memref<12x32xf32> to memref<4x4xf32, offset:?, strides: [64, 3]>
@@ -272,3 +272,161 @@ func.func @fold_static_stride_subview_with_affine_load_store(%arg0 : memref<12x3
   // CHECK-NEXT: return
   return %1 : f32
 }
+
+// -----
+
+// expand_shape of dim 12 into 2x6: the two expanded indices are linearized back into one source index (d0 * 6 + d1).
+// CHECK-DAG: #[[$MAP:.*]] = affine_map<(d0, d1) -> (d0 * 6 + d1)>
+// CHECK-LABEL: @fold_static_stride_subview_with_affine_load_store_expand_shape
+// CHECK-SAME: (%[[ARG0:.*]]: memref<12x32xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index) -> f32 {
+func.func @fold_static_stride_subview_with_affine_load_store_expand_shape(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index) -> f32 {
+  %0 = memref.expand_shape %arg0 [[0, 1], [2]] : memref<12x32xf32> into memref<2x6x32xf32>
+  %1 = affine.load %0[%arg1, %arg2, %arg3] : memref<2x6x32xf32>
+  return %1 : f32
+}
+// CHECK: %[[INDEX:.*]] = affine.apply #[[$MAP]](%[[ARG1]], %[[ARG2]])
+// CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[INDEX]], %[[ARG3]]] : memref<12x32xf32>
+// CHECK-NEXT: return %[[RESULT]] : f32
+
+// -----
+
+// collapse_shape of dims 2x6 into 12: the collapsed index is delinearized into (d0 floordiv 6, d0 mod 6).
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0) -> (d0 floordiv 6)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0) -> (d0 mod 6)>
+// CHECK-LABEL: @fold_static_stride_subview_with_affine_load_store_collapse_shape
+// CHECK-SAME: (%[[ARG0:.*]]: memref<2x6x32xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
+func.func @fold_static_stride_subview_with_affine_load_store_collapse_shape(%arg0 : memref<2x6x32xf32>, %arg1 : index, %arg2 : index) -> f32 {
+  %0 = memref.collapse_shape %arg0 [[0, 1], [2]] : memref<2x6x32xf32> into memref<12x32xf32>
+  %1 = affine.load %0[%arg1, %arg2] : memref<12x32xf32>
+  return %1 : f32
+}
+// CHECK-NEXT: %[[MODIFIED_INDEX0:.*]] = affine.apply #[[$MAP0]](%[[ARG1]])
+// CHECK-NEXT: %[[MODIFIED_INDEX1:.*]] = affine.apply #[[$MAP1]](%[[ARG1]])
+// CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[MODIFIED_INDEX0]], %[[MODIFIED_INDEX1]], %[[ARG2]]] : memref<2x6x32xf32>
+// CHECK-NEXT: return %[[RESULT]] : f32
+
+// -----
+
+// Three-way expand_shape of dim 12 into 2x2x3: the indices are linearized as d0 * 6 + d1 * 3 + d2.
+// CHECK-DAG: #[[$MAP:.*]] = affine_map<(d0, d1, d2) -> (d0 * 6 + d1 * 3 + d2)>
+// CHECK-LABEL: @fold_static_stride_subview_with_affine_load_store_expand_shape_3d
+// CHECK-SAME: (%[[ARG0:.*]]: memref<12x32xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index) -> f32 {
+func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_3d(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4: index) -> f32 {
+  %0 = memref.expand_shape %arg0 [[0, 1, 2], [3]] : memref<12x32xf32> into memref<2x2x3x32xf32>
+  %1 = affine.load %0[%arg1, %arg2, %arg3, %arg4] : memref<2x2x3x32xf32>
+  return %1 : f32
+}
+// CHECK: %[[INDEX:.*]] = affine.apply #[[$MAP]](%[[ARG1]], %[[ARG2]], %[[ARG3]])
+// CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[INDEX]], %[[ARG4]]] : memref<12x32xf32>
+// CHECK-NEXT: return %[[RESULT]] : f32
+
+// -----
+
+// expand_shape with unit dimensions, accessed from inside an affine loop nest.
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0 * 1024 + d1)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
+// CHECK-LABEL: @fold_static_stride_subview_with_affine_load_store_expand_shape_in_loop_nest
+// CHECK-SAME: (%[[ARG0:.*]]: memref<1024x1024xf32>, %[[ARG1:.*]]: memref<1xf32>, %[[ARG2:.*]]: index)
+func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_in_loop_nest(%arg0: memref<1024x1024xf32>, %arg1: memref<1xf32>, %arg2: index) -> f32 {
+  %0 = memref.expand_shape %arg0 [[0, 1], [2, 3]] : memref<1024x1024xf32> into memref<1x1024x1024x1xf32>
+  affine.for %arg3 = 0 to 1 {
+    affine.for %arg4 = 0 to 1024 {
+      affine.for %arg5 = 0 to 1020 {
+        affine.for %arg6 = 0 to 1 {
+          %1 = affine.load %0[%arg3, %arg4, %arg5, %arg6] : memref<1x1024x1024x1xf32>
+          affine.store %1, %arg1[%arg2] : memref<1xf32>
+        }
+      }
+    }
+  }
+  %2 = affine.load %arg1[%arg2] : memref<1xf32>
+  return %2 : f32
+}
+// CHECK-NEXT: affine.for %[[ARG3:.*]] = 0 to 1 {
+// CHECK-NEXT:  affine.for %[[ARG4:.*]] = 0 to 1024 {
+// CHECK-NEXT:   affine.for %[[ARG5:.*]] = 0 to 1020 {
+// CHECK-NEXT:    affine.for %[[ARG6:.*]] = 0 to 1 {
+// CHECK-NEXT:     %[[IDX1:.*]] = affine.apply #[[$MAP0]](%[[ARG3]], %[[ARG4]])
+// CHECK-NEXT:     %[[IDX2:.*]] = affine.apply #[[$MAP1]](%[[ARG5]], %[[ARG6]])
+// CHECK-NEXT:     affine.load %[[ARG0]][%[[IDX1]], %[[IDX2]]] : memref<1024x1024xf32>
+
+// -----
+
+// As above, but one access index is itself an affine expression; it is materialized with affine.apply before being linearized.
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d1 + d0)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0 * 1024 + d1)>
+// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
+// CHECK-LABEL: @fold_static_stride_subview_with_affine_load_store_expand_shape_when_access_index_is_an_expression
+// CHECK-SAME: (%[[ARG0:.*]]: memref<1024x1024xf32>, %[[ARG1:.*]]: memref<1xf32>, %[[ARG2:.*]]: index)
+func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_when_access_index_is_an_expression(%arg0: memref<1024x1024xf32>, %arg1: memref<1xf32>, %arg2: index) -> f32 {
+  %0 = memref.expand_shape %arg0 [[0, 1], [2, 3]] : memref<1024x1024xf32> into memref<1x1024x1024x1xf32>
+  affine.for %arg3 = 0 to 1 {
+    affine.for %arg4 = 0 to 1024 {
+      affine.for %arg5 = 0 to 1020 {
+        affine.for %arg6 = 0 to 1 {
+          %1 = affine.load %0[%arg3, %arg4 + %arg3, %arg5, %arg6] : memref<1x1024x1024x1xf32>
+          affine.store %1, %arg1[%arg2] : memref<1xf32>
+        }
+      }
+    }
+  }
+  %2 = affine.load %arg1[%arg2] : memref<1xf32>
+  return %2 : f32
+}
+// CHECK-NEXT: affine.for %[[ARG3:.*]] = 0 to 1 {
+// CHECK-NEXT:  affine.for %[[ARG4:.*]] = 0 to 1024 {
+// CHECK-NEXT:   affine.for %[[ARG5:.*]] = 0 to 1020 {
+// CHECK-NEXT:    affine.for %[[ARG6:.*]] = 0 to 1 {
+// CHECK-NEXT:      %[[TMP1:.*]] = affine.apply #[[$MAP0]](%[[ARG3]], %[[ARG4]], %[[ARG5]], %[[ARG6]])
+// CHECK-NEXT:      %[[TMP2:.*]] = affine.apply #[[$MAP1]](%[[ARG3]], %[[TMP1]])
+// CHECK-NEXT:      %[[TMP3:.*]] = affine.apply #[[$MAP2]](%[[ARG5]], %[[ARG6]])
+// CHECK-NEXT:      affine.load %[[ARG0]][%[[TMP2]], %[[TMP3]]] : memref<1024x1024xf32>
+
+// -----
+
+// memref.load/memref.store (rather than affine ops) with a constant access index.
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0 * 1024 + d1)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
+// CHECK-LABEL: @fold_static_stride_subview_with_affine_load_store_expand_shape_with_constant_access_index
+// CHECK-SAME: (%[[ARG0:.*]]: memref<1024x1024xf32>, %[[ARG1:.*]]: memref<1xf32>, %[[ARG2:.*]]: index)
+func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_with_constant_access_index(%arg0: memref<1024x1024xf32>, %arg1: memref<1xf32>, %arg2: index) -> f32 {
+  %0 = memref.expand_shape %arg0 [[0, 1], [2, 3]] : memref<1024x1024xf32> into memref<1x1024x1024x1xf32>
+  %cst = arith.constant 0 : index
+  affine.for %arg3 = 0 to 1 {
+    affine.for %arg4 = 0 to 1024 {
+      affine.for %arg5 = 0 to 1020 {
+        affine.for %arg6 = 0 to 1 {
+          %1 = memref.load %0[%arg3, %cst, %arg5, %arg6] : memref<1x1024x1024x1xf32>
+          memref.store %1, %arg1[%arg2] : memref<1xf32>
+        }
+      }
+    }
+  }
+  %2 = memref.load %arg1[%arg2] : memref<1xf32>
+  return %2 : f32
+}
+// CHECK-NEXT: %[[ZERO:.*]] = arith.constant 0 : index
+// CHECK-NEXT: affine.for %[[ARG3:.*]] = 0 to 1 {
+// CHECK-NEXT:  affine.for %[[ARG4:.*]] = 0 to 1024 {
+// CHECK-NEXT:   affine.for %[[ARG5:.*]] = 0 to 1020 {
+// CHECK-NEXT:    affine.for %[[ARG6:.*]] = 0 to 1 {
+// CHECK-NEXT:      %[[TMP1:.*]] = affine.apply #[[$MAP0]](%[[ARG3]], %[[ZERO]])
+// CHECK-NEXT:      %[[TMP2:.*]] = affine.apply #[[$MAP1]](%[[ARG5]], %[[ARG6]])
+// CHECK-NEXT:      memref.load %[[ARG0]][%[[TMP1]], %[[TMP2]]] : memref<1024x1024xf32>
+
+// -----
+
+// collapse_shape to a 0-d memref: the folded load indexes the source with a constant 0.
+// CHECK-LABEL: @fold_static_stride_subview_with_affine_load_store_collapse_shape_with_0d_result
+// CHECK-SAME: (%[[ARG0:.*]]: memref<1xf32>, %[[ARG1:.*]]: memref<1xf32>)
+func.func @fold_static_stride_subview_with_affine_load_store_collapse_shape_with_0d_result(%arg0: memref<1xf32>, %arg1: memref<1xf32>) -> memref<1xf32> {
+  %0 = memref.collapse_shape %arg0 [] : memref<1xf32> into memref<f32>
+  affine.for %arg2 = 0 to 3 {
+    %1 = affine.load %0[] : memref<f32>
+    affine.store %1, %arg1[0] : memref<1xf32>
+  }
+  return %arg1 : memref<1xf32>
+}
+// CHECK-NEXT: %[[ZERO:.*]] = arith.constant 0 : index
+// CHECK-NEXT: affine.for %{{.*}} = 0 to 3 {
+// CHECK-NEXT:   affine.load %[[ARG0]][%[[ZERO]]] : memref<1xf32>

diff  --git a/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp b/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp
index d942ef96f4c4f..8dbbd17ba0a08 100644
--- a/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp
+++ b/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp
@@ -46,7 +46,7 @@ static LogicalResult runMLIRPasses(ModuleOp module) {
   applyPassManagerCLOptions(passManager);
 
   passManager.addPass(createGpuKernelOutliningPass());
-  passManager.addPass(memref::createFoldSubViewOpsPass());
+  passManager.addPass(memref::createFoldMemRefAliasOpsPass());
 
   passManager.addPass(createConvertGPUToSPIRVPass(/*mapMemorySpace=*/true));
   OpPassManager &modulePM = passManager.nest<spirv::ModuleOp>();


        


More information about the Mlir-commits mailing list