[Mlir-commits] [mlir] 14a290a - [mlir][vector] Implement IndexedAccessOpInterface for load, store, etc. (#196216)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri May 8 08:39:10 PDT 2026


Author: Krzysztof Drewniak
Date: 2026-05-08T15:39:05Z
New Revision: 14a290ab3d013fa30e8efd767beb4c4f002c6175

URL: https://github.com/llvm/llvm-project/commit/14a290ab3d013fa30e8efd767beb4c4f002c6175
DIFF: https://github.com/llvm/llvm-project/commit/14a290ab3d013fa30e8efd767beb4c4f002c6175.diff

LOG: [mlir][vector] Implement IndexedAccessOpInterface for load, store, etc. (#196216)

This commit adds simple (not trying to account for unit dimensions that
could be cast away) implementations of IndexedAccessOpInterface to
low-level vector operations like vector.load and vector.store,
eliminating the need for the old-style code in FoldMemRefAliasOps.cpp.
After this commit, it'll be possible to migrate all the other
memref-rewriting passes (ExpandAddressComputation and FlattenMemRefs) to
use the interface, taking a bunch of dialect dependencies off of
memref/transforms.

Assisted-By: GPT 5.5 (pulled in old code, wrote some new tests)

Added: 
    mlir/include/mlir/Dialect/Vector/Transforms/IndexedAccessOpInterfaceImpl.h
    mlir/lib/Dialect/Vector/Transforms/IndexedAccessOpInterfaceImpl.cpp

Modified: 
    mlir/include/mlir/Dialect/MemRef/IR/MemoryAccessOpInterfaces.td
    mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
    mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
    mlir/lib/RegisterAllDialects.cpp
    mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/MemRef/IR/MemoryAccessOpInterfaces.td b/mlir/include/mlir/Dialect/MemRef/IR/MemoryAccessOpInterfaces.td
index 7fc69b4fabca6..0f1ef521afc57 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemoryAccessOpInterfaces.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemoryAccessOpInterfaces.td
@@ -53,7 +53,11 @@ def IndexedAccessOpInterface : OpInterface<"IndexedAccessOpInterface"> {
     InterfaceMethod<
       /*desc=*/[{
         Return the shape of the portion of the memref that is being accessed by
-        this operation, if known, ignoring leading unit dimensions.
+        this operation, if known. This shape describes the access dimensions
+        whose strides are semantically important for this operation.
+        Implementations shall omit dimensions whose strides do not affect the
+        operation semantics. (In particular, if an operation will access one
+        element of the base memref, this method should return `{}`.)
 
         Reindexing transformations may not modify the *strides* of the trailing
         N dimensions, where N is the size returned value, and should ensure that

diff  --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 8f4fa5ca6a844..28a8109cb59c0 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -1658,6 +1658,7 @@ def Vector_TransferWriteOp :
   let hasVerifier = 1;
 }
 
+// Promises IndexedAccessOpInterface.
 def Vector_LoadOp : Vector_Op<"load", [
     DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
     DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
@@ -1776,6 +1777,7 @@ def Vector_LoadOp : Vector_Op<"load", [
       "$base `[` $indices `]` attr-dict `:` type($base) `,` type($result)";
 }
 
+// Promises IndexedAccessOpInterface.
 def Vector_StoreOp : Vector_Op<"store", [
     DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
     DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
@@ -1883,6 +1885,7 @@ def Vector_StoreOp : Vector_Op<"store", [
                        "`:` type($base) `,` type($valueToStore)";
 }
 
+// Promises IndexedAccessOpInterface.
 def Vector_MaskedLoadOp :
   Vector_Op<"maskedload", [
       DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
@@ -1978,6 +1981,7 @@ def Vector_MaskedLoadOp :
   ];
 }
 
+// Promises IndexedAccessOpInterface.
 def Vector_MaskedStoreOp :
   Vector_Op<"maskedstore", [
       DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
@@ -2251,6 +2255,7 @@ def Vector_ScatterOp
     }]>];
 }
 
+// Promises IndexedAccessOpInterface.
 def Vector_ExpandLoadOp :
   Vector_Op<"expandload", [
       DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
@@ -2342,6 +2347,7 @@ def Vector_ExpandLoadOp :
   ];
 }
 
+// Promises IndexedAccessOpInterface.
 def Vector_CompressStoreOp :
   Vector_Op<"compressstore", [
       DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,

diff  --git a/mlir/include/mlir/Dialect/Vector/Transforms/IndexedAccessOpInterfaceImpl.h b/mlir/include/mlir/Dialect/Vector/Transforms/IndexedAccessOpInterfaceImpl.h
new file mode 100644
index 0000000000000..57fe661ad81f0
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/IndexedAccessOpInterfaceImpl.h
@@ -0,0 +1,29 @@
+//===- IndexedAccessOpInterfaceImpl.h ---------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Registration hook for the external models that implement
+// memref::IndexedAccessOpInterface on vector dialect load/store-like ops.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_VECTOR_TRANSFORMS_INDEXEDACCESSOPINTERFACEIMPL_H
+#define MLIR_DIALECT_VECTOR_TRANSFORMS_INDEXEDACCESSOPINTERFACEIMPL_H
+
+namespace mlir {
+
+class DialectRegistry;
+
+namespace vector {
+/// Register external models of memref::IndexedAccessOpInterface for
+/// vector.load, vector.store, vector.maskedload, vector.maskedstore,
+/// vector.expandload and vector.compressstore.
+void registerIndexedAccessOpInterfaceExternalModels(DialectRegistry &registry);
+} // namespace vector
+} // namespace mlir
+
+#endif // MLIR_DIALECT_VECTOR_TRANSFORMS_INDEXEDACCESSOPINTERFACEIMPL_H

diff  --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
index e36ddfa063e11..de7662753d142 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
@@ -25,7 +25,6 @@
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallBitVector.h"
-#include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/Debug.h"
 #include <cstdint>
 
@@ -52,6 +51,8 @@ hasTrivialReassociationSuffix(ArrayRef<ReassociationIndices> reassocs,
                               int64_t n) {
   if (n <= 0)
     return true;
+  if (n > static_cast<int64_t>(reassocs.size()))
+    return false;
   return llvm::all_of(
       reassocs.take_back(n),
       [&](const ReassociationIndices &indices) { return indices.size() == 1; });
@@ -60,89 +61,17 @@ hasTrivialReassociationSuffix(ArrayRef<ReassociationIndices> reassocs,
 static bool hasTrailingUnitStrides(memref::SubViewOp subview, int64_t n) {
   if (n <= 0)
     return true;
-  return llvm::all_of(subview.getStaticStrides().take_back(n),
-                      [](int64_t s) { return s == 1; });
+  ArrayRef<int64_t> strides = subview.getStaticStrides();
+  return n <= static_cast<int64_t>(strides.size()) &&
+         llvm::all_of(strides.take_back(n),
+                      [](int64_t stride) { return stride == 1; });
 }
 
-/// Helpers to access the memref operand for each op.
-template <typename LoadOrStoreOpTy>
-static Value getMemRefOperand(LoadOrStoreOpTy op) {
-  return op.getMemref();
-}
-
-static Value getMemRefOperand(vector::LoadOp op) { return op.getBase(); }
-
-static Value getMemRefOperand(vector::StoreOp op) { return op.getBase(); }
-
-static Value getMemRefOperand(vector::MaskedLoadOp op) { return op.getBase(); }
-
-static Value getMemRefOperand(vector::MaskedStoreOp op) { return op.getBase(); }
-
 //===----------------------------------------------------------------------===//
 // Patterns
 //===----------------------------------------------------------------------===//
 
 namespace {
-/// Merges subview operation with load/transferRead operation.
-template <typename OpTy>
-class LoadOpOfSubViewOpFolder final : public OpRewritePattern<OpTy> {
-public:
-  using OpRewritePattern<OpTy>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(OpTy loadOp,
-                                PatternRewriter &rewriter) const override;
-};
-
-/// Merges expand_shape operation with load/transferRead operation.
-template <typename OpTy>
-class LoadOpOfExpandShapeOpFolder final : public OpRewritePattern<OpTy> {
-public:
-  using OpRewritePattern<OpTy>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(OpTy loadOp,
-                                PatternRewriter &rewriter) const override;
-};
-
-/// Merges collapse_shape operation with load/transferRead operation.
-template <typename OpTy>
-class LoadOpOfCollapseShapeOpFolder final : public OpRewritePattern<OpTy> {
-public:
-  using OpRewritePattern<OpTy>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(OpTy loadOp,
-                                PatternRewriter &rewriter) const override;
-};
-
-/// Merges subview operation with store/transferWriteOp operation.
-template <typename OpTy>
-class StoreOpOfSubViewOpFolder final : public OpRewritePattern<OpTy> {
-public:
-  using OpRewritePattern<OpTy>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(OpTy storeOp,
-                                PatternRewriter &rewriter) const override;
-};
-
-/// Merges expand_shape operation with store/transferWriteOp operation.
-template <typename OpTy>
-class StoreOpOfExpandShapeOpFolder final : public OpRewritePattern<OpTy> {
-public:
-  using OpRewritePattern<OpTy>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(OpTy storeOp,
-                                PatternRewriter &rewriter) const override;
-};
-
-/// Merges collapse_shape operation with store/transferWriteOp operation.
-template <typename OpTy>
-class StoreOpOfCollapseShapeOpFolder final : public OpRewritePattern<OpTy> {
-public:
-  using OpRewritePattern<OpTy>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(OpTy storeOp,
-                                PatternRewriter &rewriter) const override;
-};
-
 /// Folds subview(subview(x)) to a single subview(x).
 class SubViewOfSubViewFolder : public OpRewritePattern<memref::SubViewOp> {
 public:
@@ -286,226 +215,6 @@ struct TransferOpOfCollapseShapeOpFolder final
 };
 } // namespace
 
-static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter,
-                                                Operation *op,
-                                                memref::SubViewOp subviewOp) {
-  return success();
-}
-
-template <typename OpTy>
-LogicalResult LoadOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
-    OpTy loadOp, PatternRewriter &rewriter) const {
-  auto subViewOp =
-      getMemRefOperand(loadOp).template getDefiningOp<memref::SubViewOp>();
-
-  if (!subViewOp)
-    return rewriter.notifyMatchFailure(loadOp, "not a subview producer");
-
-  LogicalResult preconditionResult =
-      preconditionsFoldSubViewOp(rewriter, loadOp, subViewOp);
-  if (failed(preconditionResult))
-    return preconditionResult;
-
-  SmallVector<Value> sourceIndices;
-  affine::resolveIndicesIntoOpWithOffsetsAndStrides(
-      rewriter, loadOp.getLoc(), subViewOp.getMixedOffsets(),
-      subViewOp.getMixedStrides(), subViewOp.getDroppedDims(),
-      loadOp.getIndices(), sourceIndices);
-
-  llvm::TypeSwitch<Operation *, void>(loadOp)
-      .Case([&](memref::LoadOp op) {
-        rewriter.replaceOpWithNewOp<memref::LoadOp>(
-            loadOp, subViewOp.getSource(), sourceIndices, op.getNontemporal());
-      })
-      .Case([&](vector::LoadOp op) {
-        rewriter.replaceOpWithNewOp<vector::LoadOp>(
-            op, op.getType(), subViewOp.getSource(), sourceIndices);
-      })
-      .Case([&](vector::MaskedLoadOp op) {
-        rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
-            op, op.getType(), subViewOp.getSource(), sourceIndices,
-            op.getMask(), op.getPassThru());
-      })
-      .DefaultUnreachable("unexpected operation");
-  return success();
-}
-
-template <typename OpTy>
-LogicalResult LoadOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
-    OpTy loadOp, PatternRewriter &rewriter) const {
-  auto expandShapeOp =
-      getMemRefOperand(loadOp).template getDefiningOp<memref::ExpandShapeOp>();
-
-  if (!expandShapeOp)
-    return failure();
-
-  SmallVector<Value> sourceIndices;
-  // memref.load guarantees that indexes start inbounds while the vector
-  // operations don't. This impacts if our linearization is `disjoint`
-  resolveSourceIndicesExpandShape(loadOp.getLoc(), rewriter, expandShapeOp,
-                                  loadOp.getIndices(), sourceIndices,
-                                  isa<memref::LoadOp>(loadOp.getOperation()));
-
-  return llvm::TypeSwitch<Operation *, LogicalResult>(loadOp)
-      .Case([&](memref::LoadOp op) {
-        rewriter.replaceOpWithNewOp<memref::LoadOp>(
-            loadOp, expandShapeOp.getViewSource(), sourceIndices,
-            op.getNontemporal());
-        return success();
-      })
-      .Case([&](vector::LoadOp op) {
-        rewriter.replaceOpWithNewOp<vector::LoadOp>(
-            op, op.getType(), expandShapeOp.getViewSource(), sourceIndices,
-            op.getNontemporal());
-        return success();
-      })
-      .Case([&](vector::MaskedLoadOp op) {
-        rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
-            op, op.getType(), expandShapeOp.getViewSource(), sourceIndices,
-            op.getMask(), op.getPassThru());
-        return success();
-      })
-      .DefaultUnreachable("unexpected operation");
-}
-
-template <typename OpTy>
-LogicalResult LoadOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
-    OpTy loadOp, PatternRewriter &rewriter) const {
-  auto collapseShapeOp = getMemRefOperand(loadOp)
-                             .template getDefiningOp<memref::CollapseShapeOp>();
-
-  if (!collapseShapeOp)
-    return failure();
-
-  SmallVector<Value> sourceIndices;
-  resolveSourceIndicesCollapseShape(loadOp.getLoc(), rewriter, collapseShapeOp,
-                                    loadOp.getIndices(), sourceIndices);
-  llvm::TypeSwitch<Operation *, void>(loadOp)
-      .Case([&](memref::LoadOp op) {
-        rewriter.replaceOpWithNewOp<memref::LoadOp>(
-            loadOp, collapseShapeOp.getViewSource(), sourceIndices,
-            op.getNontemporal());
-      })
-      .Case([&](vector::LoadOp op) {
-        rewriter.replaceOpWithNewOp<vector::LoadOp>(
-            op, op.getType(), collapseShapeOp.getViewSource(), sourceIndices,
-            op.getNontemporal());
-      })
-      .Case([&](vector::MaskedLoadOp op) {
-        rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
-            op, op.getType(), collapseShapeOp.getViewSource(), sourceIndices,
-            op.getMask(), op.getPassThru());
-      })
-      .DefaultUnreachable("unexpected operation");
-  return success();
-}
-
-template <typename OpTy>
-LogicalResult StoreOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
-    OpTy storeOp, PatternRewriter &rewriter) const {
-  auto subViewOp =
-      getMemRefOperand(storeOp).template getDefiningOp<memref::SubViewOp>();
-
-  if (!subViewOp)
-    return rewriter.notifyMatchFailure(storeOp, "not a subview producer");
-
-  LogicalResult preconditionResult =
-      preconditionsFoldSubViewOp(rewriter, storeOp, subViewOp);
-  if (failed(preconditionResult))
-    return preconditionResult;
-
-  SmallVector<Value> sourceIndices;
-  affine::resolveIndicesIntoOpWithOffsetsAndStrides(
-      rewriter, storeOp.getLoc(), subViewOp.getMixedOffsets(),
-      subViewOp.getMixedStrides(), subViewOp.getDroppedDims(),
-      storeOp.getIndices(), sourceIndices);
-
-  llvm::TypeSwitch<Operation *, void>(storeOp)
-      .Case([&](memref::StoreOp op) {
-        rewriter.replaceOpWithNewOp<memref::StoreOp>(
-            op, op.getValue(), subViewOp.getSource(), sourceIndices,
-            op.getNontemporal());
-      })
-      .Case([&](vector::StoreOp op) {
-        rewriter.replaceOpWithNewOp<vector::StoreOp>(
-            op, op.getValueToStore(), subViewOp.getSource(), sourceIndices);
-      })
-      .Case([&](vector::MaskedStoreOp op) {
-        rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
-            op, subViewOp.getSource(), sourceIndices, op.getMask(),
-            op.getValueToStore());
-      })
-      .DefaultUnreachable("unexpected operation");
-  return success();
-}
-
-template <typename OpTy>
-LogicalResult StoreOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
-    OpTy storeOp, PatternRewriter &rewriter) const {
-  auto expandShapeOp =
-      getMemRefOperand(storeOp).template getDefiningOp<memref::ExpandShapeOp>();
-
-  if (!expandShapeOp)
-    return failure();
-
-  SmallVector<Value> sourceIndices;
-  // memref.store guarantees that indexes start inbounds while the vector
-  // operations don't. This impacts if our linearization is `disjoint`
-  resolveSourceIndicesExpandShape(storeOp.getLoc(), rewriter, expandShapeOp,
-                                  storeOp.getIndices(), sourceIndices,
-                                  isa<memref::StoreOp>(storeOp.getOperation()));
-  llvm::TypeSwitch<Operation *, void>(storeOp)
-      .Case([&](memref::StoreOp op) {
-        rewriter.replaceOpWithNewOp<memref::StoreOp>(
-            storeOp, op.getValueToStore(), expandShapeOp.getViewSource(),
-            sourceIndices, op.getNontemporal());
-      })
-      .Case([&](vector::StoreOp op) {
-        rewriter.replaceOpWithNewOp<vector::StoreOp>(
-            op, op.getValueToStore(), expandShapeOp.getViewSource(),
-            sourceIndices, op.getNontemporal());
-      })
-      .Case([&](vector::MaskedStoreOp op) {
-        rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
-            op, expandShapeOp.getViewSource(), sourceIndices, op.getMask(),
-            op.getValueToStore());
-      })
-      .DefaultUnreachable("unexpected operation");
-  return success();
-}
-
-template <typename OpTy>
-LogicalResult StoreOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
-    OpTy storeOp, PatternRewriter &rewriter) const {
-  auto collapseShapeOp = getMemRefOperand(storeOp)
-                             .template getDefiningOp<memref::CollapseShapeOp>();
-
-  if (!collapseShapeOp)
-    return failure();
-
-  SmallVector<Value> sourceIndices;
-  resolveSourceIndicesCollapseShape(storeOp.getLoc(), rewriter, collapseShapeOp,
-                                    storeOp.getIndices(), sourceIndices);
-  llvm::TypeSwitch<Operation *, void>(storeOp)
-      .Case([&](memref::StoreOp op) {
-        rewriter.replaceOpWithNewOp<memref::StoreOp>(
-            storeOp, op.getValueToStore(), collapseShapeOp.getViewSource(),
-            sourceIndices, op.getNontemporal());
-      })
-      .Case([&](vector::StoreOp op) {
-        rewriter.replaceOpWithNewOp<vector::StoreOp>(
-            op, op.getValueToStore(), collapseShapeOp.getViewSource(),
-            sourceIndices, op.getNontemporal());
-      })
-      .Case([&](vector::MaskedStoreOp op) {
-        rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
-            op, collapseShapeOp.getViewSource(), sourceIndices, op.getMask(),
-            op.getValueToStore());
-      })
-      .DefaultUnreachable("unexpected operation");
-  return success();
-}
-
 LogicalResult
 AccessOpOfSubViewOpFolder::matchAndRewrite(memref::IndexedAccessOpInterface op,
                                            PatternRewriter &rewriter) const {
@@ -849,27 +558,13 @@ LogicalResult TransferOpOfCollapseShapeOpFolder::matchAndRewrite(
 }
 
 void memref::populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns) {
-  patterns.add<
-      // Interface-based patterns to which we will be migrating.
-      AccessOpOfSubViewOpFolder, AccessOpOfExpandShapeOpFolder,
-      AccessOpOfCollapseShapeOpFolder, IndexedMemCopyOpOfSubViewOpFolder,
-      IndexedMemCopyOpOfExpandShapeOpFolder,
-      IndexedMemCopyOpOfCollapseShapeOpFolder, TransferOpOfSubViewOpFolder,
-      TransferOpOfExpandShapeOpFolder, TransferOpOfCollapseShapeOpFolder,
-      // The old way of doing things. Don't add more of these.
-      LoadOpOfSubViewOpFolder<vector::LoadOp>,
-      LoadOpOfSubViewOpFolder<vector::MaskedLoadOp>,
-      StoreOpOfSubViewOpFolder<vector::StoreOp>,
-      StoreOpOfSubViewOpFolder<vector::MaskedStoreOp>,
-      LoadOpOfExpandShapeOpFolder<vector::LoadOp>,
-      LoadOpOfExpandShapeOpFolder<vector::MaskedLoadOp>,
-      StoreOpOfExpandShapeOpFolder<vector::StoreOp>,
-      StoreOpOfExpandShapeOpFolder<vector::MaskedStoreOp>,
-      LoadOpOfCollapseShapeOpFolder<vector::LoadOp>,
-      LoadOpOfCollapseShapeOpFolder<vector::MaskedLoadOp>,
-      StoreOpOfCollapseShapeOpFolder<vector::StoreOp>,
-      StoreOpOfCollapseShapeOpFolder<vector::MaskedStoreOp>,
-      SubViewOfSubViewFolder>(patterns.getContext());
+  patterns
+      .add<AccessOpOfSubViewOpFolder, AccessOpOfExpandShapeOpFolder,
+           AccessOpOfCollapseShapeOpFolder, IndexedMemCopyOpOfSubViewOpFolder,
+           IndexedMemCopyOpOfExpandShapeOpFolder,
+           IndexedMemCopyOpOfCollapseShapeOpFolder, TransferOpOfSubViewOpFolder,
+           TransferOpOfExpandShapeOpFolder, TransferOpOfCollapseShapeOpFolder,
+           SubViewOfSubViewFolder>(patterns.getContext());
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 780b0cbb36120..51be1e4431e70 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -19,6 +19,7 @@
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/IR/MemoryAccessOpInterfaces.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/UB/IR/UBMatchers.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
@@ -483,6 +484,9 @@ void VectorDialect::initialize() {
 
   addInterfaces<VectorInlinerInterface>();
 
+  declarePromisedInterfaces<memref::IndexedAccessOpInterface, LoadOp, StoreOp,
+                            MaskedLoadOp, MaskedStoreOp, ExpandLoadOp,
+                            CompressStoreOp>();
   declarePromisedInterfaces<bufferization::BufferizableOpInterface,
                             TransferReadOp, TransferWriteOp, GatherOp, MaskOp,
                             YieldOp>();

diff  --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index 4e0f07af95984..112a1db6fe93b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_dialect_library(MLIRVectorTransforms
   BufferizableOpInterfaceImpl.cpp
+  IndexedAccessOpInterfaceImpl.cpp
   LowerVectorBitCast.cpp
   LowerVectorBroadcast.cpp
   LowerVectorContract.cpp

diff  --git a/mlir/lib/Dialect/Vector/Transforms/IndexedAccessOpInterfaceImpl.cpp b/mlir/lib/Dialect/Vector/Transforms/IndexedAccessOpInterfaceImpl.cpp
new file mode 100644
index 0000000000000..c91ea97a2f965
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Transforms/IndexedAccessOpInterfaceImpl.cpp
@@ -0,0 +1,116 @@
+//===- IndexedAccessOpInterfaceImpl.cpp -----------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+// Implement IndexedAccessOpInterface on vector dialect operations with
+// %memref[%i, %j, ...] operands so generic memref-dialect passes can rewrite
+// their base/index pairs. Transfer ops keep their VectorTransferOpInterface
+// patterns; gather/scatter have tensor-or-memref bases and index-vector
+// operands that do not fit IndexedAccessOpInterface's rank-matched index
+// contract.
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Vector/Transforms/IndexedAccessOpInterfaceImpl.h"
+
+#include "mlir/Dialect/MemRef/IR/MemoryAccessOpInterfaces.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/PatternMatch.h"
+
+using namespace mlir;
+using namespace mlir::memref;
+
+/// Return true if the `base` operand of `op`, which must be a `LoadStoreOp`,
+/// has memref type. Only used to check preconditions in assertions.
+template <typename LoadStoreOp>
+static bool hasMemrefSemantics(Operation *op) {
+  return llvm::isa<MemRefType>(cast<LoadStoreOp>(op).getBase().getType());
+}
+
+/// Return the vector shape whose access strides must be preserved, marking
+/// scalable dimensions as dynamic.
+static SmallVector<int64_t> getAccessedVectorShape(VectorType vecTy) {
+  return llvm::map_to_vector(
+      llvm::zip_equal(vecTy.getShape(), vecTy.getScalableDims()), [](auto dim) {
+        auto [size, scalable] = dim;
+        return scalable ? ShapedType::kDynamic : size;
+      });
+}
+
+namespace {
+/// External model of IndexedAccessOpInterface shared by the vector ops that
+/// address a memref through a `base` operand plus an index list.
+template <typename LoadStoreOp>
+struct VectorLoadStoreLikeOpImpl final
+    : IndexedAccessOpInterface::ExternalModel<
+          VectorLoadStoreLikeOpImpl<LoadStoreOp>, LoadStoreOp> {
+  /// The accessed memref is the op's `base` operand.
+  TypedValue<MemRefType> getAccessedMemref(Operation *op) const {
+    return cast<LoadStoreOp>(op).getBase();
+  }
+
+  /// The indices applied to the `base` operand.
+  Operation::operand_range getIndices(Operation *op) const {
+    return cast<LoadStoreOp>(op).getIndices();
+  }
+
+  /// Return the full vector shape (scalable dims as dynamic). This is more
+  /// conservative than the interface asks for: e.g. a single-element vector
+  /// yields `{1}` rather than `{}`, which blocks some otherwise-legal folds.
+  SmallVector<int64_t> getAccessedShape(Operation *op) const {
+    assert(hasMemrefSemantics<LoadStoreOp>(op) &&
+           "expected vector op with memref semantics");
+    return getAccessedVectorShape(cast<LoadStoreOp>(op).getVectorType());
+  }
+
+  /// Replace the base memref and indices of `op` in place. No new values are
+  /// created, so std::nullopt is returned.
+  std::optional<SmallVector<Value>>
+  updateMemrefAndIndices(Operation *op, RewriterBase &rewriter, Value newMemref,
+                         ValueRange newIndices) const {
+    assert(hasMemrefSemantics<LoadStoreOp>(op) &&
+           "expected vector op with memref semantics");
+    assert(llvm::isa<MemRefType>(newMemref.getType()) &&
+           "expected replacement memref");
+    assert(static_cast<int64_t>(newIndices.size()) ==
+               cast<MemRefType>(newMemref.getType()).getRank() &&
+           "expected one index per dimension of the replacement memref");
+    rewriter.modifyOpInPlace(op, [&]() {
+      auto concreteOp = cast<LoadStoreOp>(op);
+      concreteOp.getBaseMutable().assign(newMemref);
+      concreteOp.getIndicesMutable().assign(newIndices);
+    });
+    return std::nullopt;
+  }
+
+  /// Vector ops do not guarantee that their start indices are in bounds, so
+  /// this conservatively returns false.
+  // TODO: The various load and store operations, at the very least vector.load
+  // and vector.store, should be taught a starts-in-bounds attribute that would
+  // let us optimize index generation.
+  bool hasInboundsIndices(Operation *op) const {
+    assert(hasMemrefSemantics<LoadStoreOp>(op) &&
+           "expected vector op with memref semantics");
+    return false;
+  }
+};
+} // namespace
+
+/// Attach VectorLoadStoreLikeOpImpl<Op> to every op type in `Ops`.
+template <typename... Ops>
+static void attachAll(MLIRContext *ctx) {
+  (Ops::template attachInterface<VectorLoadStoreLikeOpImpl<Ops>>(*ctx), ...);
+}
+
+void mlir::vector::registerIndexedAccessOpInterfaceExternalModels(
+    DialectRegistry &registry) {
+  registry.addExtension(+[](MLIRContext *ctx, vector::VectorDialect *) {
+    attachAll<vector::LoadOp, vector::StoreOp, vector::MaskedLoadOp,
+              vector::MaskedStoreOp, vector::ExpandLoadOp,
+              vector::CompressStoreOp>(ctx);
+  });
+}

diff  --git a/mlir/lib/RegisterAllDialects.cpp b/mlir/lib/RegisterAllDialects.cpp
index 589730b785133..01a7401db4710 100644
--- a/mlir/lib/RegisterAllDialects.cpp
+++ b/mlir/lib/RegisterAllDialects.cpp
@@ -95,6 +95,7 @@
 #include "mlir/Dialect/Vector/IR/ValueBoundsOpInterfaceImpl.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/Vector/Transforms/IndexedAccessOpInterfaceImpl.h"
 #include "mlir/Dialect/Vector/Transforms/SubsetOpInterfaceImpl.h"
 #include "mlir/Dialect/WasmSSA/IR/WasmSSA.h"
 #include "mlir/Dialect/X86/X86Dialect.h"
@@ -197,6 +198,7 @@ void mlir::registerAllDialects(DialectRegistry &registry) {
   tensor::registerValueBoundsOpInterfaceExternalModels(registry);
   tosa::registerShardingInterfaceExternalModels(registry);
   vector::registerBufferizableOpInterfaceExternalModels(registry);
+  vector::registerIndexedAccessOpInterfaceExternalModels(registry);
   vector::registerSubsetOpInterfaceExternalModels(registry);
   vector::registerValueBoundsOpInterfaceExternalModels(registry);
   NVVM::registerNVVMTargetInterfaceExternalModels(registry);

diff  --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
index 50c7ebaff1e6a..6e2702d936ee0 100644
--- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
+++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
@@ -642,6 +642,26 @@ func.func @fold_vector_load_subview(%src : memref<24x64xf32>,
 
 // -----
 
+// TODO: This should fold, but implementing IndexedAccessOpInterface on
+// vector.load in a way that allows the fold would add complexity (emitting
+// `vector.shape_cast`s) that earlier review discussions asked to keep out of
+// the initial implementation. (Note: this didn't work in the pre-interface
+// version of the pass either.)
+func.func @no_fold_scalar_equivalent_vector_load_subview(
+  %arg0 : memref<16xf32>, %off : index, %idx : index) -> vector<1xf32> {
+  %0 = memref.subview %arg0[%off][4][2] : memref<16xf32> to memref<4xf32, strided<[2], offset: ?>>
+  %1 = vector.load %0[%idx] : memref<4xf32, strided<[2], offset: ?>>, vector<1xf32>
+  return %1 : vector<1xf32>
+}
+
+// CHECK-LABEL: func @no_fold_scalar_equivalent_vector_load_subview
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<16xf32>
+//       CHECK:   %[[SUBVIEW:.*]] = memref.subview %[[ARG0]]
+//   CHECK-NOT:   vector.shape_cast
+//       CHECK:   vector.load %[[SUBVIEW]]
+
+// -----
+
 func.func @fold_vector_maskedload_subview(
   %arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3: vector<32xi1>, %arg4: vector<32xf32>) -> vector<32xf32> {
   %0 = memref.subview %arg0[%arg1, %arg2][1, 1][1, 1] : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
@@ -659,6 +679,21 @@ func.func @fold_vector_maskedload_subview(
 
 // -----
 
+func.func @no_fold_vector_maskedload_subview_high_rank_vector(
+  %arg0 : memref<8xf32>, %idx : index,
+  %mask : vector<2x2x2xi1>, %pass : vector<2x2x2xf32>) -> vector<2x2x2xf32> {
+  %0 = memref.subview %arg0[%idx][1][1] : memref<8xf32> to memref<1xf32, strided<[1], offset: ?>>
+  %1 = vector.maskedload %0[%idx], %mask, %pass : memref<1xf32, strided<[1], offset: ?>>, vector<2x2x2xi1>, vector<2x2x2xf32> into vector<2x2x2xf32>
+  return %1 : vector<2x2x2xf32>
+}
+
+// CHECK-LABEL: func @no_fold_vector_maskedload_subview_high_rank_vector
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<8xf32>
+//       CHECK:   %[[SUBVIEW:.*]] = memref.subview %[[ARG0]]
+//       CHECK:   vector.maskedload %[[SUBVIEW]]
+
+// -----
+
 func.func @fold_vector_store_subview(%src : memref<24x64xf32>,
                                      %off1 : index,
                                      %off2 : index,
@@ -723,6 +758,24 @@ func.func @fold_vector_load_expand_shape(
 
 // -----
 
+// Folding this would require changing the vector op rank. That is handled by
+// vector drop-leading-unit-dim patterns, not by fold-memref-alias-ops.
+func.func @no_fold_vector_load_expand_shape_leading_unit(
+  %arg0 : memref<32xf32>, %arg1 : index) -> vector<1x8xf32> {
+  %c0 = arith.constant 0 : index
+  %0 = memref.expand_shape %arg0 [[0, 1]] output_shape [4, 8] : memref<32xf32> into memref<4x8xf32>
+  %1 = vector.load %0[%arg1, %c0] : memref<4x8xf32>, vector<1x8xf32>
+  return %1 : vector<1x8xf32>
+}
+
+// CHECK-LABEL: func @no_fold_vector_load_expand_shape_leading_unit
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<32xf32>
+//       CHECK:   memref.expand_shape %[[ARG0]]
+//   CHECK-NOT:   vector.shape_cast
+//       CHECK:   vector.load
+
+// -----
+
 func.func @fold_vector_maskedload_expand_shape(
   %arg0 : memref<32xf32>, %arg1 : index, %arg3: vector<8xi1>, %arg4: vector<8xf32>) -> vector<8xf32> {
   %c0 = arith.constant 0 : index
@@ -742,6 +795,22 @@ func.func @fold_vector_maskedload_expand_shape(
 
 // -----
 
+func.func @no_fold_vector_maskedload_expand_shape_high_rank_vector(
+  %arg0 : memref<32xf32>, %arg1 : index,
+  %mask : vector<2x2x2xi1>, %pass : vector<2x2x2xf32>) -> vector<2x2x2xf32> {
+  %c0 = arith.constant 0 : index
+  %0 = memref.expand_shape %arg0 [[0, 1]] output_shape [4, 8] : memref<32xf32> into memref<4x8xf32>
+  %1 = vector.maskedload %0[%arg1, %c0], %mask, %pass : memref<4x8xf32>, vector<2x2x2xi1>, vector<2x2x2xf32> into vector<2x2x2xf32>
+  return %1 : vector<2x2x2xf32>
+}
+
+// CHECK-LABEL: func @no_fold_vector_maskedload_expand_shape_high_rank_vector
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<32xf32>
+//       CHECK:   %[[EXPAND:.*]] = memref.expand_shape %[[ARG0]]
+//       CHECK:   vector.maskedload %[[EXPAND]]
+
+// -----
+
 func.func @fold_vector_store_expand_shape(
   %arg0 : memref<32xf32>, %arg1 : index, %val : vector<8xf32>) {
   %c0 = arith.constant 0 : index
@@ -759,6 +828,24 @@ func.func @fold_vector_store_expand_shape(
 
 // -----
 
+// Folding this would require changing the vector op rank. That is handled by
+// vector drop-leading-unit-dim patterns, not by fold-memref-alias-ops.
+func.func @no_fold_vector_store_expand_shape_leading_unit(
+  %arg0 : memref<32xf32>, %arg1 : index, %val : vector<1x8xf32>) {
+  %c0 = arith.constant 0 : index
+  %0 = memref.expand_shape %arg0 [[0, 1]] output_shape [4, 8] : memref<32xf32> into memref<4x8xf32>
+  vector.store %val, %0[%arg1, %c0] : memref<4x8xf32>, vector<1x8xf32>
+  return
+}
+
+// CHECK-LABEL: func @no_fold_vector_store_expand_shape_leading_unit
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<32xf32>
+//       CHECK:   memref.expand_shape %[[ARG0]]
+//   CHECK-NOT:   vector.shape_cast
+//       CHECK:   vector.store
+
+// -----
+
 func.func @fold_vector_maskedstore_expand_shape(
   %arg0 : memref<32xf32>, %arg1 : index, %arg3: vector<8xi1>, %arg4: vector<8xf32>) {
   %c0 = arith.constant 0 : index
@@ -778,6 +865,44 @@ func.func @fold_vector_maskedstore_expand_shape(
 
 // -----
 
+func.func @fold_vector_expandload_expand_shape(
+  %arg0 : memref<32xf32>, %arg1 : index, %arg3: vector<8xi1>, %arg4: vector<8xf32>) -> vector<8xf32> {
+  %c0 = arith.constant 0 : index
+  %0 = memref.expand_shape %arg0 [[0, 1]] output_shape [4, 8] : memref<32xf32> into memref<4x8xf32>
+  %1 = vector.expandload %0[%arg1, %c0], %arg3, %arg4 : memref<4x8xf32>, vector<8xi1>, vector<8xf32> into vector<8xf32>
+  return %1 : vector<8xf32>
+}
+
+// CHECK-LABEL: func @fold_vector_expandload_expand_shape
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<32xf32>
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
+//  CHECK-SAME:   %[[ARG3:[a-zA-Z0-9_]+]]: vector<8xi1>
+//  CHECK-SAME:   %[[ARG4:[a-zA-Z0-9_]+]]: vector<8xf32>
+//       CHECK:   %[[C0:.*]] = arith.constant 0
+//       CHECK:   %[[IDX:.*]] = affine.linearize_index [%[[ARG1]], %[[C0]]] by (4, 8)
+//       CHECK:   vector.expandload %[[ARG0]][%[[IDX]]], %[[ARG3]], %[[ARG4]]
+
+// -----
+
+func.func @fold_vector_compressstore_expand_shape(
+  %arg0 : memref<32xf32>, %arg1 : index, %arg3: vector<8xi1>, %arg4: vector<8xf32>) {
+  %c0 = arith.constant 0 : index
+  %0 = memref.expand_shape %arg0 [[0, 1]] output_shape [4, 8] : memref<32xf32> into memref<4x8xf32>
+  vector.compressstore %0[%arg1, %c0], %arg3, %arg4 : memref<4x8xf32>, vector<8xi1>, vector<8xf32>
+  return
+}
+
+// CHECK-LABEL: func @fold_vector_compressstore_expand_shape
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<32xf32>
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
+//  CHECK-SAME:   %[[ARG3:[a-zA-Z0-9_]+]]: vector<8xi1>
+//  CHECK-SAME:   %[[ARG4:[a-zA-Z0-9_]+]]: vector<8xf32>
+//       CHECK:   %[[C0:.*]] = arith.constant 0
+//       CHECK:   %[[IDX:.*]] = affine.linearize_index [%[[ARG1]], %[[C0]]] by (4, 8)
+//       CHECK:   vector.compressstore %[[ARG0]][%[[IDX]]], %[[ARG3]], %[[ARG4]]
+
+// -----
+
 func.func @fold_vector_transfer_read_expand_shape(
   %arg0 : memref<32xf32>, %arg1 : index) -> vector<8xf32> {
   %c0 = arith.constant 0 : index


        


More information about the Mlir-commits mailing list