[Mlir-commits] [mlir] [mlir][MemRef] Migrate memref dialect alias op folding to interface (PR #187168)

Krzysztof Drewniak llvmlistbot at llvm.org
Tue Mar 17 18:14:47 PDT 2026


https://github.com/krzysz00 created https://github.com/llvm/llvm-project/pull/187168

This PR adds code to FoldMemRefAliasOps / --fold-memref-alias-ops to use the new IndexedAccessOpInterface and
IndexedMemCopyOpInterface, and implements those interfaces for the relevant operations in the memref dialect.

This is a reordering of the changes planned in #177014 and #177016 to make them more testable.

There are no behavior changes expected for how memref.load and memref.store behave within the alias ops folding pass, though support for new operations, like memref.prefetch, has been added.

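Some error messages have been updated because certain verifier checks for memref.load, memref.store, and memref.atomic_rmw (such as the check that the number of indices matches the memref rank) have been moved to IndexedAccessOpInterface.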

Assisted-by: Claude 4.6 (helped deal with some of the boilerplate in the rewrite patterns and with extracting the patch)

>From 3b24753e3c23040d9672ada7c9d9b3074c74a828 Mon Sep 17 00:00:00 2001
From: Krzysztof Drewniak <Krzysztof.Drewniak at amd.com>
Date: Wed, 18 Mar 2026 00:51:48 +0000
Subject: [PATCH] [mlir][MemRef] Migrate memref dialect alias op folding to
 interface

This PR adds code to FoldMemRefAliasOps / --fold-memref-alias-ops to
use the new IndexedAccessOpInterface and IndexedMemCopyOpInterface,
and implements those interfaces for the relevant operations in the
memref dialect.

This is a reordering of the changes planned in #177014 and #177016 to
make them more testable.

There are no behavior changes expected for how memref.load and
memref.store behave within the alias ops folding pass, though support
for new operations, like memref.prefetch, has been added.

Assisted-by: Claude 4.6 (helped deal with some of the boilerplate in
the rewrite patterns and with extracting the patch)
---
 mlir/include/mlir/Dialect/MemRef/IR/MemRef.h  |   1 +
 .../mlir/Dialect/MemRef/IR/MemRefOps.td       |  44 ++-
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp      |  99 ++++-
 .../MemRef/Transforms/FoldMemRefAliasOps.cpp  | 351 ++++++++++++++++--
 mlir/test/Dialect/Linalg/invalid.mlir         |   4 +-
 .../Dialect/MemRef/fold-memref-alias-ops.mlir | 145 ++++++++
 mlir/test/Dialect/MemRef/invalid.mlir         |   4 +-
 7 files changed, 593 insertions(+), 55 deletions(-)

diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
index b7abcdea10a2a..8653eca0072b6 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
@@ -11,6 +11,7 @@
 
 #include "mlir/Bytecode/BytecodeOpInterface.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/MemRef/IR/MemoryAccessOpInterfaces.h"
 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/Interfaces/AlignmentAttrInterface.h"
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 70180c101407a..9dba4d790d631 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -11,6 +11,7 @@
 
 include "mlir/Dialect/Arith/IR/ArithBase.td"
 include "mlir/Dialect/MemRef/IR/MemRefBase.td"
+include "mlir/Dialect/MemRef/IR/MemoryAccessOpInterfaces.td"
 include "mlir/Interfaces/AlignmentAttrInterface.td"
 include "mlir/Interfaces/CastInterfaces.td"
 include "mlir/Interfaces/ControlFlowInterfaces.td"
@@ -699,7 +700,8 @@ def MemRef_DimOp : MemRef_Op<"dim", [
 // DmaStartOp
 //===----------------------------------------------------------------------===//
 
-def MemRef_DmaStartOp : MemRef_Op<"dma_start"> {
+def MemRef_DmaStartOp : MemRef_Op<"dma_start", [
+    IndexedMemCopyOpInterface]> {
   let summary = "non-blocking DMA operation that starts a transfer";
   let description = [{
     Syntax:
@@ -778,6 +780,13 @@ def MemRef_DmaStartOp : MemRef_Op<"dma_start"> {
       return {(*this)->operand_begin() + 1,
               (*this)->operand_begin() + 1 + getSrcMemRefRank()};
     }
+    // Alias to getSrcMemRef() for uniformity with other DMA-like ops.
+    ::mlir::TypedValue<::mlir::MemRefType> getSrc() {
+      // This can be called before op verification, so guard against bad variadics.
+      if ((*this)->getOperands().empty())
+        return nullptr;
+      return ::llvm::dyn_cast<::mlir::TypedValue<::mlir::MemRefType>>(getSrcMemRef());
+    }
 
     // Returns the destination MemRefType for this DMA operations.
     Value getDstMemRef() { return getOperand(1 + getSrcMemRefRank()); }
@@ -786,6 +795,16 @@ def MemRef_DmaStartOp : MemRef_Op<"dma_start"> {
     unsigned getDstMemRefRank() {
       return ::llvm::cast<MemRefType>(getDstMemRef().getType()).getRank();
     }
+    // Alias to getDstMemRef() for uniformity with other DMA-like ops.
+    ::mlir::TypedValue<::mlir::MemRefType> getDst() {
+      // May run before the op verifier, and dma_start doesn't define its operands via ODS, so bounds-check first.
+      if (!getSrc())
+        return nullptr;
+      if ((*this)->getNumOperands() < (1 + getSrcMemRefRank() + 1))
+        return nullptr;
+      return ::llvm::dyn_cast<::mlir::TypedValue<::mlir::MemRefType>>(getDstMemRef());
+    }
+
     unsigned getSrcMemorySpace() {
       return ::llvm::cast<MemRefType>(getSrcMemRef().getType()).getMemorySpaceAsInt();
     }
@@ -875,6 +894,10 @@ def MemRef_DmaStartOp : MemRef_Op<"dma_start"> {
       effects.emplace_back(MemoryEffects::Read::get(), &getTagMemRefMutable(),
                            SideEffects::DefaultResource::get());
     }
+
+    void setMemrefsAndIndices(RewriterBase& rewriter,
+      Value newSrc, ValueRange newSrcIndices,
+      Value newDst, ValueRange newDstIndices);
   }];
   let hasCustomAssemblyFormat = 1;
   let hasFolder = 1;
@@ -1066,7 +1089,8 @@ def GenericAtomicRMWOp : MemRef_Op<"generic_atomic_rmw", [
       SingleBlockImplicitTerminator<"AtomicYieldOp">,
       TypesMatchWith<"result type matches element type of memref",
                      "memref", "result",
-                     "::llvm::cast<MemRefType>($_self).getElementType()">
+                     "::llvm::cast<MemRefType>($_self).getElementType()">,
+      DeclareOpInterfaceMethods<IndexedAccessOpInterface, ["getAccessedMemref"]>
     ]> {
   let summary = "atomic read-modify-write operation with a region";
   let description = [{
@@ -1243,7 +1267,8 @@ def LoadOp : MemRef_Op<"load",
       DeclareOpInterfaceMethods<AlignmentAttrOpInterface>,
       DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
       DeclareOpInterfaceMethods<PromotableMemOpInterface>,
-      DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>]> {
+      DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>,
+      DeclareOpInterfaceMethods<IndexedAccessOpInterface, ["getAccessedMemref"]>]> {
   let summary = "load operation";
   let description = [{
     The `load` op reads an element from a memref at the specified indices.
@@ -1321,7 +1346,6 @@ def LoadOp : MemRef_Op<"load",
   }];
 
   let hasFolder = 1;
-  let hasVerifier = 1;
 
   let assemblyFormat = "$memref `[` $indices `]` attr-dict `:` type($memref)";
 }
@@ -1404,7 +1428,10 @@ def MemRef_MemorySpaceCastOp : MemRef_Op<"memory_space_cast", [
 // PrefetchOp
 //===----------------------------------------------------------------------===//
 
-def MemRef_PrefetchOp : MemRef_Op<"prefetch"> {
+def MemRef_PrefetchOp : MemRef_Op<"prefetch", [
+      DeclareOpInterfaceMethods<IndexedAccessOpInterface,
+                                ["getAccessedMemref", "getAccessedType"]>
+    ]> {
   let summary = "prefetch operation";
   let description = [{
     The "prefetch" op prefetches data from a memref location described with
@@ -2021,7 +2048,8 @@ def MemRef_StoreOp : MemRef_Op<"store",
       DeclareOpInterfaceMethods<AlignmentAttrOpInterface>,
       DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
       DeclareOpInterfaceMethods<PromotableMemOpInterface>,
-      DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>]> {
+      DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>,
+      DeclareOpInterfaceMethods<IndexedAccessOpInterface, ["getAccessedMemref"]>]> {
   let summary = "store operation";
   let description = [{
     The `store` op stores an element into a memref at the specified indices.
@@ -2084,7 +2112,6 @@ def MemRef_StoreOp : MemRef_Op<"store",
   }];
 
   let hasFolder = 1;
-  let hasVerifier = 1;
 
   let assemblyFormat = [{
     $value `,` $memref `[` $indices `]` attr-dict `:` type($memref)
@@ -2494,7 +2521,8 @@ def AtomicRMWOp : MemRef_Op<"atomic_rmw", [
       AllTypesMatch<["value", "result"]>,
       TypesMatchWith<"value type matches element type of memref",
                      "memref", "value",
-                     "::llvm::cast<MemRefType>($_self).getElementType()">
+                     "::llvm::cast<MemRefType>($_self).getElementType()">,
+      DeclareOpInterfaceMethods<IndexedAccessOpInterface, ["getAccessedMemref"]>
     ]> {
   let summary = "atomic read-modify-write operation";
   let description = [{
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 404b2aacf1450..b6d18dd28af26 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1407,6 +1407,26 @@ LogicalResult DmaStartOp::fold(FoldAdaptor adaptor,
   return foldMemRefCast(*this);
 }
 
+void DmaStartOp::setMemrefsAndIndices(RewriterBase &rewriter, Value newSrc,
+                                      ValueRange newSrcIndices, Value newDst,
+                                      ValueRange newDstIndices) {
+  // Operand positions depend on the memref ranks; rebuild the whole list.
+  SmallVector<Value> newOperands;
+  newOperands.push_back(newSrc);
+  llvm::append_range(newOperands, newSrcIndices);
+  newOperands.push_back(newDst);
+  llvm::append_range(newOperands, newDstIndices);
+  newOperands.push_back(getNumElements());
+  newOperands.push_back(getTagMemRef());
+  llvm::append_range(newOperands, getTagIndices());
+  if (isStrided()) {
+    newOperands.push_back(getStride());
+    newOperands.push_back(getNumElementsPerStride());
+  }
+
+  rewriter.modifyOpInPlace(*this, [&]() { (*this)->setOperands(newOperands); });
+}
+
 // ---------------------------------------------------------------------------
 // DmaWaitOp
 // ---------------------------------------------------------------------------
@@ -1635,6 +1655,19 @@ void GenericAtomicRMWOp::print(OpAsmPrinter &p) {
   p.printOptionalAttrDict((*this)->getAttrs());
 }
 
+TypedValue<MemRefType> GenericAtomicRMWOp::getAccessedMemref() {
+  return getMemref();
+}
+
+std::optional<SmallVector<Value>> GenericAtomicRMWOp::updateMemrefAndIndices(
+    RewriterBase &rewriter, Value newMemref, ValueRange newIndices) {
+  rewriter.modifyOpInPlace(*this, [&]() {
+    getMemrefMutable().assign(newMemref);
+    getIndicesMutable().assign(newIndices);
+  });
+  return std::nullopt;
+}
+
 //===----------------------------------------------------------------------===//
 // AtomicYieldOp
 //===----------------------------------------------------------------------===//
@@ -1770,14 +1803,6 @@ GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
 // LoadOp
 //===----------------------------------------------------------------------===//
 
-LogicalResult LoadOp::verify() {
-  if (static_cast<int64_t>(getIndices().size()) != getMemRefType().getRank()) {
-    return emitOpError("incorrect number of indices for load, expected ")
-           << getMemRefType().getRank() << " but got " << getIndices().size();
-  }
-  return success();
-}
-
 OpFoldResult LoadOp::fold(FoldAdaptor adaptor) {
   /// load(memrefcast) -> load
   if (succeeded(foldMemRefCast(*this)))
@@ -1802,6 +1827,18 @@ OpFoldResult LoadOp::fold(FoldAdaptor adaptor) {
   return splatAttr.getSplatValue<Attribute>();
 }
 
+TypedValue<MemRefType> LoadOp::getAccessedMemref() { return getMemref(); }
+
+std::optional<SmallVector<Value>>
+LoadOp::updateMemrefAndIndices(RewriterBase &rewriter, Value newMemref,
+                               ValueRange newIndices) {
+  rewriter.modifyOpInPlace(*this, [&]() {
+    getMemrefMutable().assign(newMemref);
+    getIndicesMutable().assign(newIndices);
+  });
+  return std::nullopt;
+}
+
 FailureOr<std::optional<SmallVector<Value>>>
 LoadOp::bubbleDownCasts(OpBuilder &builder) {
   return mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(getMemrefMutable(),
@@ -1945,6 +1982,18 @@ LogicalResult PrefetchOp::fold(FoldAdaptor adaptor,
   return foldMemRefCast(*this);
 }
 
+TypedValue<MemRefType> PrefetchOp::getAccessedMemref() { return getMemref(); }
+
+std::optional<SmallVector<Value>>
+PrefetchOp::updateMemrefAndIndices(RewriterBase &rewriter, Value newMemref,
+                                   ValueRange newIndices) {
+  rewriter.modifyOpInPlace(*this, [&]() {
+    getMemrefMutable().assign(newMemref);
+    getIndicesMutable().assign(newIndices);
+  });
+  return std::nullopt;
+}
+
 //===----------------------------------------------------------------------===//
 // RankOp
 //===----------------------------------------------------------------------===//
@@ -2970,19 +3019,24 @@ ReshapeOp::bubbleDownCasts(OpBuilder &builder) {
 // StoreOp
 //===----------------------------------------------------------------------===//
 
-LogicalResult StoreOp::verify() {
-  if (getNumOperands() != 2 + getMemRefType().getRank())
-    return emitOpError("store index operand count not equal to memref rank");
-
-  return success();
-}
-
 LogicalResult StoreOp::fold(FoldAdaptor adaptor,
                             SmallVectorImpl<OpFoldResult> &results) {
   /// store(memrefcast) -> store
   return foldMemRefCast(*this, getValueToStore());
 }
 
+TypedValue<MemRefType> StoreOp::getAccessedMemref() { return getMemref(); }
+
+std::optional<SmallVector<Value>>
+StoreOp::updateMemrefAndIndices(RewriterBase &rewriter, Value newMemref,
+                                ValueRange newIndices) {
+  rewriter.modifyOpInPlace(*this, [&]() {
+    getMemrefMutable().assign(newMemref);
+    getIndicesMutable().assign(newIndices);
+  });
+  return std::nullopt;
+}
+
 FailureOr<std::optional<SmallVector<Value>>>
 StoreOp::bubbleDownCasts(OpBuilder &builder) {
   return mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(getMemrefMutable(),
@@ -3971,9 +4025,6 @@ ViewOp::bubbleDownCasts(OpBuilder &builder) {
 //===----------------------------------------------------------------------===//
 
 LogicalResult AtomicRMWOp::verify() {
-  if (getMemRefType().getRank() != getNumOperands() - 2)
-    return emitOpError(
-        "expects the number of subscripts to be equal to memref rank");
   switch (getKind()) {
   case arith::AtomicRMWKind::addf:
   case arith::AtomicRMWKind::maximumf:
@@ -4017,6 +4068,18 @@ AtomicRMWOp::bubbleDownCasts(OpBuilder &builder) {
                                                             getResult());
 }
 
+TypedValue<MemRefType> AtomicRMWOp::getAccessedMemref() { return getMemref(); }
+
+std::optional<SmallVector<Value>>
+AtomicRMWOp::updateMemrefAndIndices(RewriterBase &rewriter, Value newMemref,
+                                    ValueRange newIndices) {
+  rewriter.modifyOpInPlace(*this, [&]() {
+    getMemrefMutable().assign(newMemref);
+    getIndicesMutable().assign(newIndices);
+  });
+  return std::nullopt;
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
index 06c3392cd6732..9fbb8f2d5892a 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
@@ -15,17 +15,21 @@
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/IR/MemoryAccessOpInterfaces.h"
 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
 #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallBitVector.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/Debug.h"
+#include <cstdint>
 
 #define DEBUG_TYPE "fold-memref-alias-ops"
 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
@@ -43,6 +47,25 @@ using namespace mlir;
 // Utility functions
 //===----------------------------------------------------------------------===//
 
+/// Determine if the last `n` groups of `reassocs` are trivial - that is,
+/// check if each contains exactly one dimension to collapse/expand into.
+static bool
+hasTrivialReassociationSuffix(ArrayRef<ReassociationIndices> reassocs,
+                              int64_t n) {
+  if (n <= 0)
+    return true;
+  return llvm::all_of(
+      reassocs.take_back(n),
+      [](const ReassociationIndices &indices) { return indices.size() == 1; });
+}
+
+/// Returns true if the last `n` strides of `subview` are all statically 1.
+/// Non-positive `n` trivially succeeds; dynamic strides never count as 1.
+static bool hasTrailingUnitStrides(memref::SubViewOp subview, int64_t n) {
+  return n <= 0 || llvm::all_of(subview.getStaticStrides().take_back(n),
+                                [](int64_t s) { return s == 1; });
+}
+
 /// Helpers to access the memref operand for each op.
 template <typename LoadOrStoreOpTy>
 static Value getMemRefOperand(LoadOrStoreOpTy op) {
@@ -195,6 +218,82 @@ class NVGPUAsyncCopyOpSubViewOpFolder final
   LogicalResult matchAndRewrite(nvgpu::DeviceAsyncCopyOp copyOp,
                                 PatternRewriter &rewriter) const override;
 };
+
+/// Merges subview operations with load/store-like operations unless such a
+/// merger would cause the strides between dimensions accessed by that operation
+/// to change.
+struct AccessOpOfSubViewOpFolder final
+    : OpInterfaceRewritePattern<memref::IndexedAccessOpInterface> {
+  using Base::Base;
+
+  LogicalResult matchAndRewrite(memref::IndexedAccessOpInterface op,
+                                PatternRewriter &rewriter) const override;
+};
+
+/// Merge a memref.expand_shape operation with an operation that accesses a
+/// memref by index unless that operation accesses more than one dimension of
+/// memory and any dimension other than the outermost dimension accessed this
+/// way would be merged. This prevents issues from arising with, say, a
+/// vector.load of a 4x2 vector whose two trailing (accessed) dimensions would
+/// otherwise get merged.
+struct AccessOpOfExpandShapeOpFolder final
+    : OpInterfaceRewritePattern<memref::IndexedAccessOpInterface> {
+  using Base::Base;
+
+  LogicalResult matchAndRewrite(memref::IndexedAccessOpInterface op,
+                                PatternRewriter &rewriter) const override;
+};
+
+/// Merges an operation that accesses a memref by index with a
+/// memref.collapse_shape, unless this would break apart a dimension other than
+/// the outermost one that an operation accesses. This prevents, for example,
+/// transforming a load of a 3x8 vector from a 6x8 memref into a load
+/// from a 6x4x2 memref (as this would require special handling and could lead
+/// to invalid IR if that higher-dimensional memref comes from a subview) but
+/// does permit turning a load of a length-8 vector from a 6x8 memref into a
+/// load from a 3x2x8 one.
+struct AccessOpOfCollapseShapeOpFolder final
+    : OpInterfaceRewritePattern<memref::IndexedAccessOpInterface> {
+  using Base::Base;
+
+  LogicalResult matchAndRewrite(memref::IndexedAccessOpInterface op,
+                                PatternRewriter &rewriter) const override;
+};
+
+/// Merges memref.subview operations present on the source or destination
+/// operands of indexed memory copy operations (DMA operations) into those
+/// operations. This is performed unconditionally, since folding in a subview
+/// cannot change the starting position of the copy, which is what the
+/// memref/index pair represents in DMA operations.
+struct IndexedMemCopyOpOfSubViewOpFolder final
+    : OpInterfaceRewritePattern<memref::IndexedMemCopyOpInterface> {
+  using Base::Base;
+
+  LogicalResult matchAndRewrite(memref::IndexedMemCopyOpInterface op,
+                                PatternRewriter &rewriter) const override;
+};
+
+/// Merges memref.expand_shape operations that are present on the source or
+/// destination of an indexed memory copy/DMA into the memref/index arguments of
+/// that DMA. As with subviews, this can be done unconditionally.
+struct IndexedMemCopyOpOfExpandShapeOpFolder final
+    : OpInterfaceRewritePattern<memref::IndexedMemCopyOpInterface> {
+  using Base::Base;
+
+  LogicalResult matchAndRewrite(memref::IndexedMemCopyOpInterface op,
+                                PatternRewriter &rewriter) const override;
+};
+
+/// Merges memref.collapse_shape operations that are present on the source or
+/// destination of an indexed memory copy/DMA into the memref/index arguments of
+/// that DMA. As with subviews, this can be done unconditionally.
+struct IndexedMemCopyOpOfCollapseShapeOpFolder final
+    : OpInterfaceRewritePattern<memref::IndexedMemCopyOpInterface> {
+  using Base::Base;
+
+  LogicalResult matchAndRewrite(memref::IndexedMemCopyOpInterface op,
+                                PatternRewriter &rewriter) const override;
+};
 } // namespace
 
 template <typename XferOp>
@@ -516,6 +615,207 @@ LogicalResult StoreOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
   return success();
 }
 
+LogicalResult
+AccessOpOfSubViewOpFolder::matchAndRewrite(memref::IndexedAccessOpInterface op,
+                                           PatternRewriter &rewriter) const {
+  auto subview = op.getAccessedMemref().getDefiningOp<memref::SubViewOp>();
+  if (!subview)
+    return rewriter.notifyMatchFailure(op, "not accessing a subview");
+
+  SmallVector<int64_t> accessedShape = op.getAccessedShape();
+  // Note the subtle difference between accessedShape = {1} and accessedShape =
+  // {} here. The former prevents us from folding in a subview that doesn't
+  // have a unit stride on the final dimension, while the latter does not (since
+  // it indicates a scalar access).
+  int64_t accessedDims = accessedShape.size();
+  if (!hasTrailingUnitStrides(subview, accessedDims))
+    return rewriter.notifyMatchFailure(
+        op, "non-unit stride on accessed dimensions");
+
+  llvm::SmallBitVector droppedDims = subview.getDroppedDims();
+  int64_t sourceRank = subview.getSourceType().getRank();
+
+  // Ignore the outermost accessed dimension; only dimensions dropped after it
+  // can break the access op's semantics. Walk back over the non-dropped source
+  // dims to find the one backing the outermost accessed result dimension.
+  int64_t outerAccessedSrcDim = sourceRank;
+  for (int64_t seen = 0;
+       accessedDims > 1 && seen < accessedDims && outerAccessedSrcDim > 0;)
+    seen += droppedDims.test(--outerAccessedSrcDim) ? 0 : 1;
+  for (int64_t d = outerAccessedSrcDim + 1; d < sourceRank; ++d)
+    if (droppedDims.test(d))
+      return rewriter.notifyMatchFailure(
+          op, "reintroducing dropped dimension " + Twine(d) +
+                  " would break access op semantics");
+
+  SmallVector<Value> sourceIndices;
+  affine::resolveIndicesIntoOpWithOffsetsAndStrides(
+      rewriter, op.getLoc(), subview.getMixedOffsets(),
+      subview.getMixedStrides(), droppedDims, op.getIndices(), sourceIndices);
+
+  std::optional<SmallVector<Value>> newValues =
+      op.updateMemrefAndIndices(rewriter, subview.getSource(), sourceIndices);
+  if (newValues)
+    rewriter.replaceOp(op, *newValues);
+  return success();
+}
+
+LogicalResult AccessOpOfExpandShapeOpFolder::matchAndRewrite(
+    memref::IndexedAccessOpInterface op, PatternRewriter &rewriter) const {
+  auto expand = op.getAccessedMemref().getDefiningOp<memref::ExpandShapeOp>();
+  if (!expand)
+    return rewriter.notifyMatchFailure(op, "not accessing an expand_shape");
+
+  SmallVector<int64_t> rawAccessedShape = op.getAccessedShape();
+  ArrayRef<int64_t> accessedShape = rawAccessedShape;
+  // Cut off the leading dimension, since we don't care about modifying its
+  // strides.
+  if (!accessedShape.empty())
+    accessedShape = accessedShape.drop_front();
+
+  auto reassocs = expand.getReassociationIndices();
+  if (!hasTrivialReassociationSuffix(reassocs, accessedShape.size()))
+    return rewriter.notifyMatchFailure(
+        op,
+        "expand_shape folding would merge semantically important dimensions");
+
+  SmallVector<Value> sourceIndices;
+  memref::resolveSourceIndicesExpandShape(op.getLoc(), rewriter, expand,
+                                          op.getIndices(), sourceIndices,
+                                          op.hasInboundsIndices());
+
+  std::optional<SmallVector<Value>> newValues = op.updateMemrefAndIndices(
+      rewriter, expand.getViewSource(), sourceIndices);
+  if (newValues)
+    rewriter.replaceOp(op, *newValues);
+  return success();
+}
+
+LogicalResult AccessOpOfCollapseShapeOpFolder::matchAndRewrite(
+    memref::IndexedAccessOpInterface op, PatternRewriter &rewriter) const {
+  auto collapse =
+      op.getAccessedMemref().getDefiningOp<memref::CollapseShapeOp>();
+  if (!collapse)
+    return rewriter.notifyMatchFailure(op, "not accessing a collapse_shape");
+
+  SmallVector<int64_t> rawAccessedShape = op.getAccessedShape();
+  ArrayRef<int64_t> accessedShape = rawAccessedShape;
+  // Cut off the leading dimension, since we don't care about its strides being
+  // modified and we know that the dimensions within its reassociation group, if
+  // it's non-trivial, must be contiguous.
+  if (!accessedShape.empty())
+    accessedShape = accessedShape.drop_front();
+
+  auto reassocs = collapse.getReassociationIndices();
+  if (!hasTrivialReassociationSuffix(reassocs, accessedShape.size()))
+    return rewriter.notifyMatchFailure(op,
+                                       "collapse_shape folding would merge "
+                                       "semantically important dimensions");
+
+  SmallVector<Value> sourceIndices;
+  memref::resolveSourceIndicesCollapseShape(op.getLoc(), rewriter, collapse,
+                                            op.getIndices(), sourceIndices);
+
+  std::optional<SmallVector<Value>> newValues = op.updateMemrefAndIndices(
+      rewriter, collapse.getViewSource(), sourceIndices);
+  if (newValues)
+    rewriter.replaceOp(op, *newValues);
+  return success();
+}
+
+LogicalResult IndexedMemCopyOpOfSubViewOpFolder::matchAndRewrite(
+    memref::IndexedMemCopyOpInterface op, PatternRewriter &rewriter) const {
+  auto srcSubview = op.getSrc().getDefiningOp<memref::SubViewOp>();
+  auto dstSubview = op.getDst().getDefiningOp<memref::SubViewOp>();
+  if (!srcSubview && !dstSubview)
+    return rewriter.notifyMatchFailure(
+        op, "no subviews found on indexed copy inputs");
+
+  Value newSrc = op.getSrc();
+  SmallVector<Value> newSrcIndices = llvm::to_vector(op.getSrcIndices());
+  Value newDst = op.getDst();
+  SmallVector<Value> newDstIndices = llvm::to_vector(op.getDstIndices());
+  if (srcSubview) {
+    newSrc = srcSubview.getSource();
+    newSrcIndices.clear();
+    affine::resolveIndicesIntoOpWithOffsetsAndStrides(
+        rewriter, op.getLoc(), srcSubview.getMixedOffsets(),
+        srcSubview.getMixedStrides(), srcSubview.getDroppedDims(),
+        op.getSrcIndices(), newSrcIndices);
+  }
+  if (dstSubview) {
+    newDst = dstSubview.getSource();
+    newDstIndices.clear();
+    affine::resolveIndicesIntoOpWithOffsetsAndStrides(
+        rewriter, op.getLoc(), dstSubview.getMixedOffsets(),
+        dstSubview.getMixedStrides(), dstSubview.getDroppedDims(),
+        op.getDstIndices(), newDstIndices);
+  }
+  op.setMemrefsAndIndices(rewriter, newSrc, newSrcIndices, newDst,
+                          newDstIndices);
+  return success();
+}
+
+LogicalResult IndexedMemCopyOpOfExpandShapeOpFolder::matchAndRewrite(
+    memref::IndexedMemCopyOpInterface op, PatternRewriter &rewriter) const {
+  auto srcExpand = op.getSrc().getDefiningOp<memref::ExpandShapeOp>();
+  auto dstExpand = op.getDst().getDefiningOp<memref::ExpandShapeOp>();
+  if (!srcExpand && !dstExpand)
+    return rewriter.notifyMatchFailure(
+        op, "no expand_shapes found on indexed copy inputs");
+
+  Value newSrc = op.getSrc();
+  SmallVector<Value> newSrcIndices = llvm::to_vector(op.getSrcIndices());
+  Value newDst = op.getDst();
+  SmallVector<Value> newDstIndices = llvm::to_vector(op.getDstIndices());
+  if (srcExpand) {
+    newSrc = srcExpand.getViewSource();
+    newSrcIndices.clear();
+    memref::resolveSourceIndicesExpandShape(op.getLoc(), rewriter, srcExpand,
+                                            op.getSrcIndices(), newSrcIndices,
+                                            /*startsInbounds=*/true);
+  }
+  if (dstExpand) {
+    newDst = dstExpand.getViewSource();
+    newDstIndices.clear();
+    memref::resolveSourceIndicesExpandShape(op.getLoc(), rewriter, dstExpand,
+                                            op.getDstIndices(), newDstIndices,
+                                            /*startsInbounds=*/true);
+  }
+  op.setMemrefsAndIndices(rewriter, newSrc, newSrcIndices, newDst,
+                          newDstIndices);
+  return success();
+}
+
+LogicalResult IndexedMemCopyOpOfCollapseShapeOpFolder::matchAndRewrite(
+    memref::IndexedMemCopyOpInterface op, PatternRewriter &rewriter) const {
+  auto srcCollapse = op.getSrc().getDefiningOp<memref::CollapseShapeOp>();
+  auto dstCollapse = op.getDst().getDefiningOp<memref::CollapseShapeOp>();
+  if (!srcCollapse && !dstCollapse)
+    return rewriter.notifyMatchFailure(
+        op, "no collapse_shapes found on indexed copy inputs");
+
+  Value newSrc = op.getSrc();
+  SmallVector<Value> newSrcIndices = llvm::to_vector(op.getSrcIndices());
+  Value newDst = op.getDst();
+  SmallVector<Value> newDstIndices = llvm::to_vector(op.getDstIndices());
+  if (srcCollapse) {
+    newSrc = srcCollapse.getViewSource();
+    newSrcIndices.clear();
+    memref::resolveSourceIndicesCollapseShape(
+        op.getLoc(), rewriter, srcCollapse, op.getSrcIndices(), newSrcIndices);
+  }
+  if (dstCollapse) {
+    newDst = dstCollapse.getViewSource();
+    newDstIndices.clear();
+    memref::resolveSourceIndicesCollapseShape(
+        op.getLoc(), rewriter, dstCollapse, op.getDstIndices(), newDstIndices);
+  }
+  op.setMemrefsAndIndices(rewriter, newSrc, newSrcIndices, newDst,
+                          newDstIndices);
+  return success();
+}
+
 LogicalResult NVGPUAsyncCopyOpSubViewOpFolder::matchAndRewrite(
     nvgpu::DeviceAsyncCopyOp copyOp, PatternRewriter &rewriter) const {
 
@@ -568,31 +868,32 @@ LogicalResult NVGPUAsyncCopyOpSubViewOpFolder::matchAndRewrite(
 }
 
 void memref::populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns) {
-  patterns.add<LoadOpOfSubViewOpFolder<memref::LoadOp>,
-               LoadOpOfSubViewOpFolder<nvgpu::LdMatrixOp>,
-               LoadOpOfSubViewOpFolder<vector::LoadOp>,
-               LoadOpOfSubViewOpFolder<vector::MaskedLoadOp>,
-               LoadOpOfSubViewOpFolder<vector::TransferReadOp>,
-               LoadOpOfSubViewOpFolder<gpu::SubgroupMmaLoadMatrixOp>,
-               StoreOpOfSubViewOpFolder<memref::StoreOp>,
-               StoreOpOfSubViewOpFolder<vector::TransferWriteOp>,
-               StoreOpOfSubViewOpFolder<vector::StoreOp>,
-               StoreOpOfSubViewOpFolder<vector::MaskedStoreOp>,
-               StoreOpOfSubViewOpFolder<gpu::SubgroupMmaStoreMatrixOp>,
-               LoadOpOfExpandShapeOpFolder<memref::LoadOp>,
-               LoadOpOfExpandShapeOpFolder<vector::LoadOp>,
-               LoadOpOfExpandShapeOpFolder<vector::MaskedLoadOp>,
-               LoadOpOfExpandShapeOpFolder<vector::TransferReadOp>,
-               StoreOpOfExpandShapeOpFolder<memref::StoreOp>,
-               StoreOpOfExpandShapeOpFolder<vector::StoreOp>,
-               StoreOpOfExpandShapeOpFolder<vector::MaskedStoreOp>,
-               LoadOpOfCollapseShapeOpFolder<memref::LoadOp>,
-               LoadOpOfCollapseShapeOpFolder<vector::LoadOp>,
-               LoadOpOfCollapseShapeOpFolder<vector::MaskedLoadOp>,
-               StoreOpOfCollapseShapeOpFolder<memref::StoreOp>,
-               StoreOpOfCollapseShapeOpFolder<vector::StoreOp>,
-               StoreOpOfCollapseShapeOpFolder<vector::MaskedStoreOp>,
-               SubViewOfSubViewFolder, NVGPUAsyncCopyOpSubViewOpFolder>(
+  patterns.add<
+      // Interface-based patterns: the migration target for every op below.
+      AccessOpOfSubViewOpFolder, AccessOpOfExpandShapeOpFolder,
+      AccessOpOfCollapseShapeOpFolder, IndexedMemCopyOpOfSubViewOpFolder,
+      IndexedMemCopyOpOfExpandShapeOpFolder,
+      IndexedMemCopyOpOfCollapseShapeOpFolder,
+      // Legacy per-op patterns: don't add more, implement the interfaces.
+      LoadOpOfSubViewOpFolder<nvgpu::LdMatrixOp>,
+      LoadOpOfSubViewOpFolder<vector::LoadOp>,
+      LoadOpOfSubViewOpFolder<vector::MaskedLoadOp>,
+      LoadOpOfSubViewOpFolder<vector::TransferReadOp>,
+      LoadOpOfSubViewOpFolder<gpu::SubgroupMmaLoadMatrixOp>,
+      StoreOpOfSubViewOpFolder<vector::TransferWriteOp>,
+      StoreOpOfSubViewOpFolder<vector::StoreOp>,
+      StoreOpOfSubViewOpFolder<vector::MaskedStoreOp>,
+      StoreOpOfSubViewOpFolder<gpu::SubgroupMmaStoreMatrixOp>,
+      LoadOpOfExpandShapeOpFolder<vector::LoadOp>,
+      LoadOpOfExpandShapeOpFolder<vector::MaskedLoadOp>,
+      LoadOpOfExpandShapeOpFolder<vector::TransferReadOp>,
+      StoreOpOfExpandShapeOpFolder<vector::StoreOp>,
+      StoreOpOfExpandShapeOpFolder<vector::MaskedStoreOp>,
+      LoadOpOfCollapseShapeOpFolder<vector::LoadOp>,
+      LoadOpOfCollapseShapeOpFolder<vector::MaskedLoadOp>,
+      StoreOpOfCollapseShapeOpFolder<vector::StoreOp>,
+      StoreOpOfCollapseShapeOpFolder<vector::MaskedStoreOp>,
+      SubViewOfSubViewFolder, NVGPUAsyncCopyOpSubViewOpFolder>(
       patterns.getContext());
 }
 
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 9500d00a5e647..1d69870825536 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -1,7 +1,7 @@
 // RUN: mlir-opt %s -split-input-file -verify-diagnostics
 
 func.func @load_number_of_indices(%v : memref<f32>) {
-  // expected-error @+2 {{incorrect number of indices for load}}
+  // expected-error @+2 {{invalid number of indices for accessed memref, expected 0 but got 1}}
   %c0 = arith.constant 0 : index
   memref.load %v[%c0] : memref<f32>
 }
@@ -9,7 +9,7 @@ func.func @load_number_of_indices(%v : memref<f32>) {
 // -----
 
 func.func @store_number_of_indices(%v : memref<f32>) {
-  // expected-error @+3 {{store index operand count not equal to memref rank}}
+  // expected-error @+3 {{invalid number of indices for accessed memref, expected 0 but got 1}}
   %c0 = arith.constant 0 : index
   %f0 = arith.constant 0.0 : f32
   memref.store %f0, %v[%c0] : memref<f32>
diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
index 3f77a0553fff9..114ba86cda718 100644
--- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
+++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
@@ -529,6 +529,20 @@ func.func @fold_store_keep_nontemporal(%arg0 : memref<12x32xf32>, %arg1 : index,
 
 // -----
 
+func.func @fold_prefetch_expand_shape(%src: memref<32xf32>, %i0: index, %i1: index) {
+  %expand = memref.expand_shape %src [[0, 1]] output_shape [4, 8] : memref<32xf32> into memref<4x8xf32>
+  memref.prefetch %expand[%i0, %i1], read, locality<2>, data : memref<4x8xf32>
+  return
+}
+
+// CHECK-LABEL: func @fold_prefetch_expand_shape
+//  CHECK-SAME: (%[[SRC:.+]]: memref<32xf32>, %[[I0:.+]]: index, %[[I1:.+]]: index)
+//   CHECK-NOT: memref.expand_shape
+//       CHECK: %[[LIN:.+]] = affine.linearize_index disjoint [%[[I0]], %[[I1]]] by (4, 8)
+//       CHECK: memref.prefetch %[[SRC]][%[[LIN]]], read, locality<2>, data : memref<32xf32>
+
+// -----
+
 func.func @fold_gpu_subgroup_mma_load_matrix_1d(%src: memref<?xvector<4xf32>>, %offset: index, %i: index) -> !gpu.mma_matrix<16x16xf16, "COp"> {
   %subview = memref.subview %src[%offset] [81920] [1] : memref<?xvector<4xf32>> to memref<81920xvector<4xf32>, strided<[1], offset: ?>>
   %matrix = gpu.subgroup_mma_load_matrix %subview[%i] {leadDimension = 160 : index} : memref<81920xvector<4xf32>, strided<[1], offset: ?>> -> !gpu.mma_matrix<16x16xf16, "COp">
@@ -968,3 +982,134 @@ func.func @fold_vector_maskedstore_collapse_shape(
 //  CHECK-SAME:   %[[ARG4:[a-zA-Z0-9_]+]]: vector<8xf32>
 //       CHECK:   %[[IDXS:.*]]:2 = affine.delinearize_index %[[ARG1]] into (4, 8)
 //       CHECK:   vector.maskedstore %[[ARG0]][%[[IDXS]]#0, %[[IDXS]]#1], %[[ARG3]], %[[ARG4]]
+
+// -----
+func.func @fold_dma_start_subview_src(
+    %src : memref<128x64xf32>, %dst : memref<32xf32, 1>, %tag : memref<1xi32>,
+    %off0 : index, %off1 : index) {
+  %c0 = arith.constant 0 : index
+  %num_elements = arith.constant 32 : index
+  %subview = memref.subview %src[%off0, %off1][32, 32][1, 1] : memref<128x64xf32> to memref<32x32xf32, strided<[64, 1], offset: ?>>
+  memref.dma_start %subview[%c0, %c0], %dst[%c0], %num_elements, %tag[%c0] : memref<32x32xf32, strided<[64, 1], offset: ?>>, memref<32xf32, 1>, memref<1xi32>
+  return
+}
+
+// CHECK-LABEL: func @fold_dma_start_subview_src
+// CHECK-SAME:   %[[SRC:[a-zA-Z0-9_]+]]: memref<128x64xf32>
+// CHECK-SAME:   %[[DST:[a-zA-Z0-9_]+]]: memref<32xf32, 1>
+// CHECK-SAME:   %[[TAG:[a-zA-Z0-9_]+]]: memref<1xi32>
+// CHECK-SAME:   %[[OFF0:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[OFF1:[a-zA-Z0-9_]+]]: index
+//  CHECK-DAG:   %[[C0:.*]] = arith.constant 0
+//  CHECK-DAG:   %[[NUM:.*]] = arith.constant 32
+//  CHECK-NOT:   memref.subview
+//      CHECK:   memref.dma_start %[[SRC]][%[[OFF0]], %[[OFF1]]], %[[DST]][%[[C0]]], %[[NUM]], %[[TAG]][%[[C0]]]
+
+// -----
+func.func @fold_dma_start_subview_dst(
+    %src : memref<32xf32>, %dst : memref<128x64xf32, 1>, %tag : memref<1xi32>,
+    %off0 : index, %off1 : index) {
+  %c0 = arith.constant 0 : index
+  %num_elements = arith.constant 32 : index
+  %subview = memref.subview %dst[%off0, %off1][32, 32][1, 1] : memref<128x64xf32, 1> to memref<32x32xf32, strided<[64, 1], offset: ?>, 1>
+  memref.dma_start %src[%c0], %subview[%c0, %c0], %num_elements, %tag[%c0] : memref<32xf32>, memref<32x32xf32, strided<[64, 1], offset: ?>, 1>, memref<1xi32>
+  return
+}
+// CHECK-LABEL: func @fold_dma_start_subview_dst
+// CHECK-SAME:   %[[SRC:[a-zA-Z0-9_]+]]: memref<32xf32>
+// CHECK-SAME:   %[[DST:[a-zA-Z0-9_]+]]: memref<128x64xf32, 1>
+// CHECK-SAME:   %[[TAG:[a-zA-Z0-9_]+]]: memref<1xi32>
+// CHECK-SAME:   %[[OFF0:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[OFF1:[a-zA-Z0-9_]+]]: index
+//  CHECK-DAG:   %[[C0:.*]] = arith.constant 0
+//  CHECK-DAG:   %[[NUM:.*]] = arith.constant 32
+//  CHECK-NOT:   memref.subview
+//      CHECK:   memref.dma_start %[[SRC]][%[[C0]]], %[[DST]][%[[OFF0]], %[[OFF1]]], %[[NUM]], %[[TAG]][%[[C0]]]
+
+// -----
+func.func @fold_dma_start_expand_shape_src(
+    %src : memref<32xf32>, %dst : memref<8xf32, 1>, %tag : memref<1xi32>,
+    %idx : index) {
+  %c0 = arith.constant 0 : index
+  %num_elements = arith.constant 8 : index
+  %expand = memref.expand_shape %src [[0, 1]] output_shape [4, 8] : memref<32xf32> into memref<4x8xf32>
+  memref.dma_start %expand[%idx, %c0], %dst[%c0], %num_elements, %tag[%c0] : memref<4x8xf32>, memref<8xf32, 1>, memref<1xi32>
+  return
+}
+
+// CHECK-LABEL: func @fold_dma_start_expand_shape_src
+// CHECK-SAME:   %[[SRC:[a-zA-Z0-9_]+]]: memref<32xf32>
+// CHECK-SAME:   %[[DST:[a-zA-Z0-9_]+]]: memref<8xf32, 1>
+// CHECK-SAME:   %[[TAG:[a-zA-Z0-9_]+]]: memref<1xi32>
+// CHECK-SAME:   %[[IDX:[a-zA-Z0-9_]+]]: index
+//  CHECK-DAG:   %[[C0:.*]] = arith.constant 0
+//  CHECK-DAG:   %[[NUM:.*]] = arith.constant 8
+//  CHECK-NOT:   memref.expand_shape
+//      CHECK:   %[[I:.*]] = affine.linearize_index disjoint [%[[IDX]], %[[C0]]] by (4, 8)
+//      CHECK:   memref.dma_start %[[SRC]][%[[I]]], %[[DST]][%[[C0]]], %[[NUM]], %[[TAG]][%[[C0]]]
+
+// -----
+func.func @fold_dma_start_expand_shape_dst(
+    %src : memref<8xf32>, %dst : memref<32xf32, 1>, %tag : memref<1xi32>,
+    %idx : index) {
+  %c0 = arith.constant 0 : index
+  %num_elements = arith.constant 8 : index
+  %expand = memref.expand_shape %dst [[0, 1]] output_shape [4, 8] : memref<32xf32, 1> into memref<4x8xf32, 1>
+  memref.dma_start %src[%c0], %expand[%idx, %c0], %num_elements, %tag[%c0] : memref<8xf32>, memref<4x8xf32, 1>, memref<1xi32>
+  return
+}
+
+// CHECK-LABEL: func @fold_dma_start_expand_shape_dst
+// CHECK-SAME:   %[[SRC:[a-zA-Z0-9_]+]]: memref<8xf32>
+// CHECK-SAME:   %[[DST:[a-zA-Z0-9_]+]]: memref<32xf32, 1>
+// CHECK-SAME:   %[[TAG:[a-zA-Z0-9_]+]]: memref<1xi32>
+// CHECK-SAME:   %[[IDX:[a-zA-Z0-9_]+]]: index
+//  CHECK-DAG:   %[[C0:.*]] = arith.constant 0
+//  CHECK-DAG:   %[[NUM:.*]] = arith.constant 8
+//  CHECK-NOT:   memref.expand_shape
+//      CHECK:   %[[I:.*]] = affine.linearize_index disjoint [%[[IDX]], %[[C0]]] by (4, 8)
+//      CHECK:   memref.dma_start %[[SRC]][%[[C0]]], %[[DST]][%[[I]]], %[[NUM]], %[[TAG]][%[[C0]]]
+
+// -----
+func.func @fold_dma_start_collapse_shape_src(
+    %src : memref<4x8xf32>, %dst : memref<8xf32, 1>, %tag : memref<1xi32>,
+    %idx : index) {
+  %c0 = arith.constant 0 : index
+  %num_elements = arith.constant 8 : index
+  %collapse = memref.collapse_shape %src [[0, 1]] : memref<4x8xf32> into memref<32xf32>
+  memref.dma_start %collapse[%idx], %dst[%c0], %num_elements, %tag[%c0] : memref<32xf32>, memref<8xf32, 1>, memref<1xi32>
+  return
+}
+
+// CHECK-LABEL: func @fold_dma_start_collapse_shape_src
+// CHECK-SAME:   %[[SRC:[a-zA-Z0-9_]+]]: memref<4x8xf32>
+// CHECK-SAME:   %[[DST:[a-zA-Z0-9_]+]]: memref<8xf32, 1>
+// CHECK-SAME:   %[[TAG:[a-zA-Z0-9_]+]]: memref<1xi32>
+// CHECK-SAME:   %[[IDX:[a-zA-Z0-9_]+]]: index
+//  CHECK-DAG:   %[[C0:.*]] = arith.constant 0
+//  CHECK-DAG:   %[[NUM:.*]] = arith.constant 8
+//  CHECK-NOT:   memref.collapse_shape
+//      CHECK:   %[[IDXS:.*]]:2 = affine.delinearize_index %[[IDX]] into (4, 8)
+//      CHECK:   memref.dma_start %[[SRC]][%[[IDXS]]#0, %[[IDXS]]#1], %[[DST]][%[[C0]]], %[[NUM]], %[[TAG]][%[[C0]]]
+
+// -----
+func.func @fold_dma_start_collapse_shape_dst(
+    %src : memref<8xf32>, %dst : memref<4x8xf32, 1>, %tag : memref<1xi32>,
+    %idx : index) {
+  %c0 = arith.constant 0 : index
+  %num_elements = arith.constant 8 : index
+  %collapse = memref.collapse_shape %dst [[0, 1]] : memref<4x8xf32, 1> into memref<32xf32, 1>
+  memref.dma_start %src[%c0], %collapse[%idx], %num_elements, %tag[%c0] : memref<8xf32>, memref<32xf32, 1>, memref<1xi32>
+  return
+}
+
+// CHECK-LABEL: func @fold_dma_start_collapse_shape_dst
+// CHECK-SAME:   %[[SRC:[a-zA-Z0-9_]+]]: memref<8xf32>
+// CHECK-SAME:   %[[DST:[a-zA-Z0-9_]+]]: memref<4x8xf32, 1>
+// CHECK-SAME:   %[[TAG:[a-zA-Z0-9_]+]]: memref<1xi32>
+// CHECK-SAME:   %[[IDX:[a-zA-Z0-9_]+]]: index
+//  CHECK-DAG:   %[[C0:.*]] = arith.constant 0
+//  CHECK-DAG:   %[[NUM:.*]] = arith.constant 8
+//  CHECK-NOT:   memref.collapse_shape
+//      CHECK:   %[[IDXS:.*]]:2 = affine.delinearize_index %[[IDX]] into (4, 8)
+//      CHECK:   memref.dma_start %[[SRC]][%[[C0]]], %[[DST]][%[[IDXS]]#0, %[[IDXS]]#1], %[[NUM]], %[[TAG]][%[[C0]]]
diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index af068d8ca8e95..d3670fde08d81 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -956,7 +956,7 @@ func.func @bad_alloc_wrong_symbol_count() {
 func.func @load_invalid_memref_indexes() {
   %0 = memref.alloca() : memref<10xi32>
   %c0 = arith.constant 0 : index
-  // expected-error@+1 {{incorrect number of indices for load, expected 1 but got 2}}
+  // expected-error@+1 {{invalid number of indices for accessed memref, expected 1 but got 2}}
   %1 = memref.load %0[%c0, %c0] : memref<10xi32>
 }
 
@@ -1042,7 +1042,7 @@ func.func @illegal_num_offsets(%arg0 : memref<?x?x?xf32>, %arg1 : index, %arg2 :
 // -----
 
 func.func @atomic_rmw_idxs_rank_mismatch(%I: memref<16x10xf32>, %i : index, %val : f32) {
-  // expected-error@+1 {{expects the number of subscripts to be equal to memref rank}}
+  // expected-error@+1 {{invalid number of indices for accessed memref, expected 2 but got 1}}
   %x = memref.atomic_rmw addf %val, %I[%i] : (f32, memref<16x10xf32>) -> f32
   return
 }



More information about the Mlir-commits mailing list