[Mlir-commits] [mlir] [mlir][MemRef] Migrate memref dialect alias op folding to interface (PR #187168)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Mar 17 18:15:22 PDT 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-memref

@llvm/pr-subscribers-mlir-linalg

Author: Krzysztof Drewniak (krzysz00)

<details>
<summary>Changes</summary>

This PR updates FoldMemRefAliasOps (`--fold-memref-alias-ops`) to use the new IndexedAccessOpInterface and
IndexedMemCopyOpInterface, and implements those interfaces for the relevant operations in the memref dialect.

This reorders the changes planned in #177014 and #177016 to make them more testable.

No behavior changes are expected for memref.load and memref.store within the alias-op folding pass, though support has been added for new operations such as memref.prefetch.

Some error messages have changed because certain verifier checks for memref.load/memref.store (such as the check that the number of indices matches the memref rank) have moved into IndexedAccessOpInterface.

Assisted-by: Claude 4.6 (helped deal with some of the boilerplate in the rewrite patterns and with extracting the patch)

---

Patch is 41.26 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/187168.diff


7 Files Affected:

- (modified) mlir/include/mlir/Dialect/MemRef/IR/MemRef.h (+1) 
- (modified) mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td (+36-8) 
- (modified) mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp (+81-18) 
- (modified) mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp (+326-25) 
- (modified) mlir/test/Dialect/Linalg/invalid.mlir (+2-2) 
- (modified) mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir (+145) 
- (modified) mlir/test/Dialect/MemRef/invalid.mlir (+2-2) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
index b7abcdea10a2a..8653eca0072b6 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
@@ -11,6 +11,7 @@
 
 #include "mlir/Bytecode/BytecodeOpInterface.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/MemRef/IR/MemoryAccessOpInterfaces.h"
 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/Interfaces/AlignmentAttrInterface.h"
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 70180c101407a..9dba4d790d631 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -11,6 +11,7 @@
 
 include "mlir/Dialect/Arith/IR/ArithBase.td"
 include "mlir/Dialect/MemRef/IR/MemRefBase.td"
+include "mlir/Dialect/MemRef/IR/MemoryAccessOpInterfaces.td"
 include "mlir/Interfaces/AlignmentAttrInterface.td"
 include "mlir/Interfaces/CastInterfaces.td"
 include "mlir/Interfaces/ControlFlowInterfaces.td"
@@ -699,7 +700,8 @@ def MemRef_DimOp : MemRef_Op<"dim", [
 // DmaStartOp
 //===----------------------------------------------------------------------===//
 
-def MemRef_DmaStartOp : MemRef_Op<"dma_start"> {
+def MemRef_DmaStartOp : MemRef_Op<"dma_start", [
+    IndexedMemCopyOpInterface]> {
   let summary = "non-blocking DMA operation that starts a transfer";
   let description = [{
     Syntax:
@@ -778,6 +780,13 @@ def MemRef_DmaStartOp : MemRef_Op<"dma_start"> {
       return {(*this)->operand_begin() + 1,
               (*this)->operand_begin() + 1 + getSrcMemRefRank()};
     }
+    // Typed accessor for the source memref, named getSrc() for uniformity with
+    // other DMA-like ops. Unlike getSrcMemRef(), this returns null if the op has
+    // no operands or if the value returned by getSrcMemRef() is not a memref.
+    ::mlir::TypedValue<::mlir::MemRefType> getSrc() {
+      // This can be called before the op is verified, so guard against a
+      // malformed variadic operand list.
+      if ((*this)->getOperands().empty())
+        return nullptr;
+      return ::llvm::dyn_cast<::mlir::TypedValue<::mlir::MemRefType>>(getSrcMemRef());
+    }
 
     // Returns the destination MemRefType for this DMA operations.
     Value getDstMemRef() { return getOperand(1 + getSrcMemRefRank()); }
@@ -786,6 +795,16 @@ def MemRef_DmaStartOp : MemRef_Op<"dma_start"> {
     unsigned getDstMemRefRank() {
       return ::llvm::cast<MemRefType>(getDstMemRef().getType()).getRank();
     }
+    // Typed accessor for the destination memref, named getDst() for uniformity
+    // with other DMA-like ops. Returns null when the operand list is too short
+    // to contain a destination memref or when that operand is not a memref.
+    ::mlir::TypedValue<::mlir::MemRefType> getDst() {
+      // Guardrails: this can run before the op verifier, and dma_start does not
+      // declare its operands through ODS, so validate the layout by hand.
+      if (!getSrc())
+        return nullptr;
+      // The destination memref follows the source memref and its indices.
+      if ((*this)->getNumOperands() < (1 + getSrcMemRefRank() + 1))
+        return nullptr;
+      return ::llvm::dyn_cast<::mlir::TypedValue<::mlir::MemRefType>>(getDstMemRef());
+    }
+
     unsigned getSrcMemorySpace() {
       return ::llvm::cast<MemRefType>(getSrcMemRef().getType()).getMemorySpaceAsInt();
     }
@@ -875,6 +894,10 @@ def MemRef_DmaStartOp : MemRef_Op<"dma_start"> {
       effects.emplace_back(MemoryEffects::Read::get(), &getTagMemRefMutable(),
                            SideEffects::DefaultResource::get());
     }
+
+    void setMemrefsAndIndices(RewriterBase& rewriter,
+      Value newSrc, ValueRange newSrcIndices,
+      Value newDst, ValueRange newDstIndices);
   }];
   let hasCustomAssemblyFormat = 1;
   let hasFolder = 1;
@@ -1066,7 +1089,8 @@ def GenericAtomicRMWOp : MemRef_Op<"generic_atomic_rmw", [
       SingleBlockImplicitTerminator<"AtomicYieldOp">,
       TypesMatchWith<"result type matches element type of memref",
                      "memref", "result",
-                     "::llvm::cast<MemRefType>($_self).getElementType()">
+                     "::llvm::cast<MemRefType>($_self).getElementType()">,
+      DeclareOpInterfaceMethods<IndexedAccessOpInterface, ["getAccessedMemref"]>,
     ]> {
   let summary = "atomic read-modify-write operation with a region";
   let description = [{
@@ -1243,7 +1267,8 @@ def LoadOp : MemRef_Op<"load",
       DeclareOpInterfaceMethods<AlignmentAttrOpInterface>,
       DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
       DeclareOpInterfaceMethods<PromotableMemOpInterface>,
-      DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>]> {
+      DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>,
+      DeclareOpInterfaceMethods<IndexedAccessOpInterface, ["getAccessedMemref"]>]> {
   let summary = "load operation";
   let description = [{
     The `load` op reads an element from a memref at the specified indices.
@@ -1321,7 +1346,6 @@ def LoadOp : MemRef_Op<"load",
   }];
 
   let hasFolder = 1;
-  let hasVerifier = 1;
 
   let assemblyFormat = "$memref `[` $indices `]` attr-dict `:` type($memref)";
 }
@@ -1404,7 +1428,10 @@ def MemRef_MemorySpaceCastOp : MemRef_Op<"memory_space_cast", [
 // PrefetchOp
 //===----------------------------------------------------------------------===//
 
-def MemRef_PrefetchOp : MemRef_Op<"prefetch"> {
+def MemRef_PrefetchOp : MemRef_Op<"prefetch", [
+      DeclareOpInterfaceMethods<IndexedAccessOpInterface,
+                                ["getAccessedMemref", "getAccessedType"]>
+    ]> {
   let summary = "prefetch operation";
   let description = [{
     The "prefetch" op prefetches data from a memref location described with
@@ -2021,7 +2048,8 @@ def MemRef_StoreOp : MemRef_Op<"store",
       DeclareOpInterfaceMethods<AlignmentAttrOpInterface>,
       DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
       DeclareOpInterfaceMethods<PromotableMemOpInterface>,
-      DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>]> {
+      DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>,
+      DeclareOpInterfaceMethods<IndexedAccessOpInterface, ["getAccessedMemref"]>]> {
   let summary = "store operation";
   let description = [{
     The `store` op stores an element into a memref at the specified indices.
@@ -2084,7 +2112,6 @@ def MemRef_StoreOp : MemRef_Op<"store",
   }];
 
   let hasFolder = 1;
-  let hasVerifier = 1;
 
   let assemblyFormat = [{
     $value `,` $memref `[` $indices `]` attr-dict `:` type($memref)
@@ -2494,7 +2521,8 @@ def AtomicRMWOp : MemRef_Op<"atomic_rmw", [
       AllTypesMatch<["value", "result"]>,
       TypesMatchWith<"value type matches element type of memref",
                      "memref", "value",
-                     "::llvm::cast<MemRefType>($_self).getElementType()">
+                     "::llvm::cast<MemRefType>($_self).getElementType()">,
+      DeclareOpInterfaceMethods<IndexedAccessOpInterface, ["getAccessedMemref"]>
     ]> {
   let summary = "atomic read-modify-write operation";
   let description = [{
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 404b2aacf1450..b6d18dd28af26 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1407,6 +1407,26 @@ LogicalResult DmaStartOp::fold(FoldAdaptor adaptor,
   return foldMemRefCast(*this);
 }
 
+/// Replaces the source and destination memrefs (and their indices) of this
+/// dma_start in place, notifying `rewriter` of the modification. The number of
+/// elements, the tag memref and its indices, and the optional stride operands
+/// are carried over unchanged.
+void DmaStartOp::setMemrefsAndIndices(RewriterBase &rewriter, Value newSrc,
+                                      ValueRange newSrcIndices, Value newDst,
+                                      ValueRange newDstIndices) {
+  // dma_start's operand list is the flat sequence
+  //   [src, srcIndices..., dst, dstIndices..., numElements,
+  //    tag, tagIndices..., (stride, numElementsPerStride)?]
+  // and the accessors used below locate operands by position in the current
+  // list, so collect everything before mutating the op and then rebuild the
+  // whole list at once.
+  SmallVector<Value> newOperands;
+  newOperands.push_back(newSrc);
+  llvm::append_range(newOperands, newSrcIndices);
+  newOperands.push_back(newDst);
+  llvm::append_range(newOperands, newDstIndices);
+  newOperands.push_back(getNumElements());
+  newOperands.push_back(getTagMemRef());
+  llvm::append_range(newOperands, getTagIndices());
+  if (isStrided()) {
+    newOperands.push_back(getStride());
+    newOperands.push_back(getNumElementsPerStride());
+  }
+
+  rewriter.modifyOpInPlace(*this, [&]() { (*this)->setOperands(newOperands); });
+}
+
 // ---------------------------------------------------------------------------
 // DmaWaitOp
 // ---------------------------------------------------------------------------
@@ -1635,6 +1655,19 @@ void GenericAtomicRMWOp::print(OpAsmPrinter &p) {
   p.printOptionalAttrDict((*this)->getAttrs());
 }
 
+/// IndexedAccessOpInterface: the memref accessed by this op is its `memref`
+/// operand.
+TypedValue<MemRefType> GenericAtomicRMWOp::getAccessedMemref() {
+  return getMemref();
+}
+
+/// IndexedAccessOpInterface: replaces the `memref` and `indices` operands in
+/// place, notifying `rewriter`. Always returns std::nullopt because the op
+/// itself is updated (presumably signalling that no replacement values are
+/// needed; see the interface definition).
+std::optional<SmallVector<Value>> GenericAtomicRMWOp::updateMemrefAndIndices(
+    RewriterBase &rewriter, Value newMemref, ValueRange newIndices) {
+  rewriter.modifyOpInPlace(*this, [&]() {
+    getMemrefMutable().assign(newMemref);
+    getIndicesMutable().assign(newIndices);
+  });
+  return std::nullopt;
+}
+
 //===----------------------------------------------------------------------===//
 // AtomicYieldOp
 //===----------------------------------------------------------------------===//
@@ -1770,14 +1803,6 @@ GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
 // LoadOp
 //===----------------------------------------------------------------------===//
 
-LogicalResult LoadOp::verify() {
-  if (static_cast<int64_t>(getIndices().size()) != getMemRefType().getRank()) {
-    return emitOpError("incorrect number of indices for load, expected ")
-           << getMemRefType().getRank() << " but got " << getIndices().size();
-  }
-  return success();
-}
-
 OpFoldResult LoadOp::fold(FoldAdaptor adaptor) {
   /// load(memrefcast) -> load
   if (succeeded(foldMemRefCast(*this)))
@@ -1802,6 +1827,18 @@ OpFoldResult LoadOp::fold(FoldAdaptor adaptor) {
   return splatAttr.getSplatValue<Attribute>();
 }
 
+/// IndexedAccessOpInterface: the memref read by this op is its `memref`
+/// operand.
+TypedValue<MemRefType> LoadOp::getAccessedMemref() { return getMemref(); }
+
+/// IndexedAccessOpInterface: replaces the `memref` and `indices` operands in
+/// place, notifying `rewriter`. Always returns std::nullopt because the op
+/// itself is updated (presumably signalling that no replacement values are
+/// needed; see the interface definition).
+std::optional<SmallVector<Value>>
+LoadOp::updateMemrefAndIndices(RewriterBase &rewriter, Value newMemref,
+                               ValueRange newIndices) {
+  rewriter.modifyOpInPlace(*this, [&]() {
+    getMemrefMutable().assign(newMemref);
+    getIndicesMutable().assign(newIndices);
+  });
+  return std::nullopt;
+}
+
 FailureOr<std::optional<SmallVector<Value>>>
 LoadOp::bubbleDownCasts(OpBuilder &builder) {
   return mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(getMemrefMutable(),
@@ -1945,6 +1982,18 @@ LogicalResult PrefetchOp::fold(FoldAdaptor adaptor,
   return foldMemRefCast(*this);
 }
 
+/// IndexedAccessOpInterface: the memref prefetched by this op is its `memref`
+/// operand.
+TypedValue<MemRefType> PrefetchOp::getAccessedMemref() { return getMemref(); }
+
+/// IndexedAccessOpInterface: replaces the `memref` and `indices` operands in
+/// place, notifying `rewriter`. Always returns std::nullopt because the op
+/// itself is updated (presumably signalling that no replacement values are
+/// needed; see the interface definition).
+std::optional<SmallVector<Value>>
+PrefetchOp::updateMemrefAndIndices(RewriterBase &rewriter, Value newMemref,
+                                   ValueRange newIndices) {
+  rewriter.modifyOpInPlace(*this, [&]() {
+    getMemrefMutable().assign(newMemref);
+    getIndicesMutable().assign(newIndices);
+  });
+  return std::nullopt;
+}
+
 //===----------------------------------------------------------------------===//
 // RankOp
 //===----------------------------------------------------------------------===//
@@ -2970,19 +3019,24 @@ ReshapeOp::bubbleDownCasts(OpBuilder &builder) {
 // StoreOp
 //===----------------------------------------------------------------------===//
 
-LogicalResult StoreOp::verify() {
-  if (getNumOperands() != 2 + getMemRefType().getRank())
-    return emitOpError("store index operand count not equal to memref rank");
-
-  return success();
-}
-
 LogicalResult StoreOp::fold(FoldAdaptor adaptor,
                             SmallVectorImpl<OpFoldResult> &results) {
   /// store(memrefcast) -> store
   return foldMemRefCast(*this, getValueToStore());
 }
 
+/// IndexedAccessOpInterface: the memref written by this op is its `memref`
+/// operand.
+TypedValue<MemRefType> StoreOp::getAccessedMemref() { return getMemref(); }
+
+/// IndexedAccessOpInterface: replaces the `memref` and `indices` operands in
+/// place, notifying `rewriter`; the stored value is left untouched. Always
+/// returns std::nullopt because the op itself is updated (presumably
+/// signalling that no replacement values are needed; see the interface
+/// definition).
+std::optional<SmallVector<Value>>
+StoreOp::updateMemrefAndIndices(RewriterBase &rewriter, Value newMemref,
+                                ValueRange newIndices) {
+  rewriter.modifyOpInPlace(*this, [&]() {
+    getMemrefMutable().assign(newMemref);
+    getIndicesMutable().assign(newIndices);
+  });
+  return std::nullopt;
+}
+
 FailureOr<std::optional<SmallVector<Value>>>
 StoreOp::bubbleDownCasts(OpBuilder &builder) {
   return mlir::detail::bubbleDownInPlaceMemorySpaceCastImpl(getMemrefMutable(),
@@ -3971,9 +4025,6 @@ ViewOp::bubbleDownCasts(OpBuilder &builder) {
 //===----------------------------------------------------------------------===//
 
 LogicalResult AtomicRMWOp::verify() {
-  if (getMemRefType().getRank() != getNumOperands() - 2)
-    return emitOpError(
-        "expects the number of subscripts to be equal to memref rank");
   switch (getKind()) {
   case arith::AtomicRMWKind::addf:
   case arith::AtomicRMWKind::maximumf:
@@ -4017,6 +4068,18 @@ AtomicRMWOp::bubbleDownCasts(OpBuilder &builder) {
                                                             getResult());
 }
 
+/// IndexedAccessOpInterface: the memref updated by this op is its `memref`
+/// operand.
+TypedValue<MemRefType> AtomicRMWOp::getAccessedMemref() { return getMemref(); }
+
+/// IndexedAccessOpInterface: replaces the `memref` and `indices` operands in
+/// place, notifying `rewriter`. Always returns std::nullopt because the op
+/// itself is updated (presumably signalling that no replacement values are
+/// needed; see the interface definition).
+std::optional<SmallVector<Value>>
+AtomicRMWOp::updateMemrefAndIndices(RewriterBase &rewriter, Value newMemref,
+                                    ValueRange newIndices) {
+  rewriter.modifyOpInPlace(*this, [&]() {
+    getMemrefMutable().assign(newMemref);
+    getIndicesMutable().assign(newIndices);
+  });
+  return std::nullopt;
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
index 06c3392cd6732..9fbb8f2d5892a 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
@@ -15,17 +15,21 @@
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/IR/MemoryAccessOpInterfaces.h"
 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
 #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallBitVector.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/Debug.h"
+#include <cstdint>
 
 #define DEBUG_TYPE "fold-memref-alias-ops"
 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
@@ -43,6 +47,25 @@ using namespace mlir;
 // Utility functions
 //===----------------------------------------------------------------------===//
 
+/// Determines whether the last `n` reassociation groups in `reassocs` are
+/// trivial, that is, whether each of them contains exactly one dimension to
+/// collapse/expand into. Returns true when `n` is non-positive.
+static bool
+hasTrivialReassociationSuffix(ArrayRef<ReassociationIndices> reassocs,
+                              int64_t n) {
+  if (n <= 0)
+    return true;
+  return llvm::all_of(
+      reassocs.take_back(n),
+      [&](const ReassociationIndices &indices) { return indices.size() == 1; });
+}
+
+/// Returns true if the last `n` static strides of `subview` are all 1, or if
+/// `n` is non-positive. NOTE(review): a dynamic stride is assumed to be encoded
+/// in getStaticStrides() as a sentinel other than 1, so it fails this check.
+static bool hasTrailingUnitStrides(memref::SubViewOp subview, int64_t n) {
+  if (n <= 0)
+    return true;
+  return llvm::all_of(subview.getStaticStrides().take_back(n),
+                      [](int64_t s) { return s == 1; });
+}
+
 /// Helpers to access the memref operand for each op.
 template <typename LoadOrStoreOpTy>
 static Value getMemRefOperand(LoadOrStoreOpTy op) {
@@ -195,6 +218,82 @@ class NVGPUAsyncCopyOpSubViewOpFolder final
   LogicalResult matchAndRewrite(nvgpu::DeviceAsyncCopyOp copyOp,
                                 PatternRewriter &rewriter) const override;
 };
+
+/// Merges memref.subview operations into load/store-like operations
+/// (IndexedAccessOpInterface), unless the merge would change the strides
+/// between the dimensions accessed by that operation.
+struct AccessOpOfSubViewOpFolder final
+    : OpInterfaceRewritePattern<memref::IndexedAccessOpInterface> {
+  using Base::Base;
+
+  LogicalResult matchAndRewrite(memref::IndexedAccessOpInterface op,
+                                PatternRewriter &rewriter) const override;
+};
+
+/// Merges a memref.expand_shape operation into an operation that accesses a
+/// memref by index, unless that operation accesses more than one dimension of
+/// memory and any accessed dimension other than the outermost one would be
+/// merged. This prevents, for example, the two trailing dimensions of a
+/// vector.load of a 4x2 vector from being merged.
+struct AccessOpOfExpandShapeOpFolder final
+    : OpInterfaceRewritePattern<memref::IndexedAccessOpInterface> {
+  using Base::Base;
+
+  LogicalResult matchAndRewrite(memref::IndexedAccessOpInterface op,
+                                PatternRewriter &rewriter) const override;
+};
+
+/// Merges a memref.collapse_shape into an operation that accesses the
+/// collapsed memref by index, unless this would break apart a dimension other
+/// than the outermost one that the operation accesses. This prevents, for
+/// example, rewriting a load of a 3x8 vector from a 6x8 memref (collapsed from
+/// 6x4x2) into a load from the 6x4x2 memref, as this would require special
+/// handling and could lead to invalid IR if that higher-dimensional memref
+/// comes from a subview. It does permit rewriting a load of a length-8 vector
+/// from a 6x8 memref (collapsed from 3x2x8) into a load from the 3x2x8 one.
+struct AccessOpOfCollapseShapeOpFolder final
+    : OpInterfaceRewritePattern<memref::IndexedAccessOpInterface> {
+  using Base::Base;
+
+  LogicalResult matchAndRewrite(memref::IndexedAccessOpInterface op,
+                                PatternRewriter &rewriter) const override;
+};
+
+/// Merges memref.subview operations on the source or destination operands of
+/// indexed memory copy operations (DMA operations, IndexedMemCopyOpInterface)
+/// into those operations. This is performed unconditionally: folding in a
+/// subview cannot change the starting position of the copy, which is what the
+/// memref/index pair represents in DMA operations.
+struct IndexedMemCopyOpOfSubViewOpFolder final
+    : OpInterfaceRewritePattern<memref::IndexedMemCopyOpInterface> {
+  using Base::Base;
+
+  LogicalResult matchAndRewrite(memref::IndexedMemCopyOpInterface op,
+                                PatternRewriter &rewriter) const override;
+};
+
+/// Merges memref.expand_shape operations on the source or destination of an
+/// indexed memory copy/DMA (IndexedMemCopyOpInterface) into the memref/index
+/// operands of that DMA. As with subviews, this is done unconditionally.
+struct IndexedMemCopyOpOfExpandShapeOpFolder final
+    : OpInterfaceRewritePattern<memref::IndexedMemCopyOpInterface> {
+  using Base::Base;
+
+  LogicalResult matchAndRewrite(memref::IndexedMemCopyOpInterface op,
+                                PatternRewriter &rewriter) const override;
+};
+
+/// Merges memref.collapse_shape operations on the source or destination of an
+/// indexed memory copy/DMA (IndexedMemCopyOpInterface) into the memref/index
+/// operands of that DMA. As with subviews, this is done unconditionally.
+struct IndexedMemCopyOpOfCollapseShapeOpFolder final
+    : OpInterfaceRewritePattern<memref::IndexedMemCopyOpInterface> {
+  using Base::Base;
+
+  LogicalResult matchAndRewrite(memref::IndexedMemCopyOpInterface op,
+                                PatternRewriter &rewriter) const override;
+};
 } // namespace
 
 template <typename XferOp>
@@ -516,6 +615,207 @@ LogicalResult StoreOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
   return success();
 }
 
+LogicalResult
+AccessOpOfSubViewOpFolder::matchAndRewrite(memref::IndexedAccessOpInterface op,
+                                           PatternRewriter &rewriter) const {
+  auto subview = op.getAccessedMemref().getDefiningOp<memref::SubViewOp>();
+  if (!subview)
+    return rewriter.notifyMatchFailure(op, "not accessing a subview");
+
+  SmallVector<int64_t> accessedShape = op.getAccessedShape();
+  // Note the subtle difference between accesedShape = {1} and accessedShape =
+  // {} here. The former prevents us from fdolding in a subview that doesn't
+  // have a unit stride on the final dimension, while the latter does not (since
+  // it indices scalar accesss).
+  int64_t accessedDims = accessedShape.size();
+  if (!hasTrailingUnitStrides(subview, accessedDims))
+    return rewriter.notifyMatchFailure(
+        op, "non-unit stride on accessed dimensions");
+
+  llvm::SmallBitVector droppedDims ...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/187168


More information about the Mlir-commits mailing list