[Mlir-commits] [mlir] 14a290a - [mlir][vector] Implement IndexedAccessOpInterface for load, store, etc. (#196216)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri May 8 08:39:10 PDT 2026
Author: Krzysztof Drewniak
Date: 2026-05-08T15:39:05Z
New Revision: 14a290ab3d013fa30e8efd767beb4c4f002c6175
URL: https://github.com/llvm/llvm-project/commit/14a290ab3d013fa30e8efd767beb4c4f002c6175
DIFF: https://github.com/llvm/llvm-project/commit/14a290ab3d013fa30e8efd767beb4c4f002c6175.diff
LOG: [mlir][vector] Implement IndexedAccessOpInterface for load, store, etc. (#196216)
This commit adds simple (not trying to account for unit dimensions that
could be cast away) implementations of IndexedAccessOpInterface to
low-level vector operations like vector.load and vector.store,
eliminating the need for the old-style code in FoldMemRefAliasOps.cpp.
After this commit, it'll be possible to migrate all the other
memref-rewriting passes (ExpandAddressComputation and FlattenMemRefs) to
use the interface, taking a bunch of dialect dependencies off of
memref/transforms.
Assisted-By: GPT 5.5 (pulled in old code, wrote some new tests)
Added:
mlir/include/mlir/Dialect/Vector/Transforms/IndexedAccessOpInterfaceImpl.h
mlir/lib/Dialect/Vector/Transforms/IndexedAccessOpInterfaceImpl.cpp
Modified:
mlir/include/mlir/Dialect/MemRef/IR/MemoryAccessOpInterfaces.td
mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
mlir/lib/RegisterAllDialects.cpp
mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemoryAccessOpInterfaces.td b/mlir/include/mlir/Dialect/MemRef/IR/MemoryAccessOpInterfaces.td
index 7fc69b4fabca6..0f1ef521afc57 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemoryAccessOpInterfaces.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemoryAccessOpInterfaces.td
@@ -53,7 +53,11 @@ def IndexedAccessOpInterface : OpInterface<"IndexedAccessOpInterface"> {
InterfaceMethod<
/*desc=*/[{
Return the shape of the portion of the memref that is being accessed by
- this operation, if known, ignoring leading unit dimensions.
+ this operation, if known. This shape describes the access dimensions
+ whose strides are semantically important for this operation.
+ Implementations shall omit dimensions whose strides do not affect the
+ operation semantics. (In particular, if an operation will access one
+ element of the base memref, this method should return `{}`.)
Reindexing transformations may not modify the *strides* of the trailing
N dimensions, where N is the size returned value, and should ensure that
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 8f4fa5ca6a844..28a8109cb59c0 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -1658,6 +1658,7 @@ def Vector_TransferWriteOp :
let hasVerifier = 1;
}
+// Promises IndexedAccessOpInterface.
def Vector_LoadOp : Vector_Op<"load", [
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
@@ -1776,6 +1777,7 @@ def Vector_LoadOp : Vector_Op<"load", [
"$base `[` $indices `]` attr-dict `:` type($base) `,` type($result)";
}
+// Promises IndexedAccessOpInterface.
def Vector_StoreOp : Vector_Op<"store", [
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
@@ -1883,6 +1885,7 @@ def Vector_StoreOp : Vector_Op<"store", [
"`:` type($base) `,` type($valueToStore)";
}
+// Promises IndexedAccessOpInterface.
def Vector_MaskedLoadOp :
Vector_Op<"maskedload", [
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
@@ -1978,6 +1981,7 @@ def Vector_MaskedLoadOp :
];
}
+// Promises IndexedAccessOpInterface.
def Vector_MaskedStoreOp :
Vector_Op<"maskedstore", [
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
@@ -2251,6 +2255,7 @@ def Vector_ScatterOp
}]>];
}
+// Promises IndexedAccessOpInterface.
def Vector_ExpandLoadOp :
Vector_Op<"expandload", [
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
@@ -2342,6 +2347,7 @@ def Vector_ExpandLoadOp :
];
}
+// Promises IndexedAccessOpInterface.
def Vector_CompressStoreOp :
Vector_Op<"compressstore", [
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/IndexedAccessOpInterfaceImpl.h b/mlir/include/mlir/Dialect/Vector/Transforms/IndexedAccessOpInterfaceImpl.h
new file mode 100644
index 0000000000000..57fe661ad81f0
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/IndexedAccessOpInterfaceImpl.h
@@ -0,0 +1,21 @@
+//===- IndexedAccessOpInterfaceImpl.h ---------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_VECTOR_TRANSFORMS_INDEXEDACCESSOPINTERFACEIMPL_H
+#define MLIR_DIALECT_VECTOR_TRANSFORMS_INDEXEDACCESSOPINTERFACEIMPL_H
+
+namespace mlir {
+
+class DialectRegistry;
+
+namespace vector {
+void registerIndexedAccessOpInterfaceExternalModels(DialectRegistry &registry);
+} // namespace vector
+} // namespace mlir
+
+#endif // MLIR_DIALECT_VECTOR_TRANSFORMS_INDEXEDACCESSOPINTERFACEIMPL_H
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
index e36ddfa063e11..de7662753d142 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
@@ -25,7 +25,6 @@
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
-#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include <cstdint>
@@ -52,6 +51,8 @@ hasTrivialReassociationSuffix(ArrayRef<ReassociationIndices> reassocs,
int64_t n) {
if (n <= 0)
return true;
+ if (n > static_cast<int64_t>(reassocs.size()))
+ return false;
return llvm::all_of(
reassocs.take_back(n),
[&](const ReassociationIndices &indices) { return indices.size() == 1; });
@@ -60,89 +61,17 @@ hasTrivialReassociationSuffix(ArrayRef<ReassociationIndices> reassocs,
static bool hasTrailingUnitStrides(memref::SubViewOp subview, int64_t n) {
if (n <= 0)
return true;
- return llvm::all_of(subview.getStaticStrides().take_back(n),
- [](int64_t s) { return s == 1; });
+ ArrayRef<int64_t> strides = subview.getStaticStrides();
+ if (n > static_cast<int64_t>(strides.size()))
+ return false;
+ return llvm::all_of(strides.take_back(n), [](int64_t s) { return s == 1; });
}
-/// Helpers to access the memref operand for each op.
-template <typename LoadOrStoreOpTy>
-static Value getMemRefOperand(LoadOrStoreOpTy op) {
- return op.getMemref();
-}
-
-static Value getMemRefOperand(vector::LoadOp op) { return op.getBase(); }
-
-static Value getMemRefOperand(vector::StoreOp op) { return op.getBase(); }
-
-static Value getMemRefOperand(vector::MaskedLoadOp op) { return op.getBase(); }
-
-static Value getMemRefOperand(vector::MaskedStoreOp op) { return op.getBase(); }
-
//===----------------------------------------------------------------------===//
// Patterns
//===----------------------------------------------------------------------===//
namespace {
-/// Merges subview operation with load/transferRead operation.
-template <typename OpTy>
-class LoadOpOfSubViewOpFolder final : public OpRewritePattern<OpTy> {
-public:
- using OpRewritePattern<OpTy>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(OpTy loadOp,
- PatternRewriter &rewriter) const override;
-};
-
-/// Merges expand_shape operation with load/transferRead operation.
-template <typename OpTy>
-class LoadOpOfExpandShapeOpFolder final : public OpRewritePattern<OpTy> {
-public:
- using OpRewritePattern<OpTy>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(OpTy loadOp,
- PatternRewriter &rewriter) const override;
-};
-
-/// Merges collapse_shape operation with load/transferRead operation.
-template <typename OpTy>
-class LoadOpOfCollapseShapeOpFolder final : public OpRewritePattern<OpTy> {
-public:
- using OpRewritePattern<OpTy>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(OpTy loadOp,
- PatternRewriter &rewriter) const override;
-};
-
-/// Merges subview operation with store/transferWriteOp operation.
-template <typename OpTy>
-class StoreOpOfSubViewOpFolder final : public OpRewritePattern<OpTy> {
-public:
- using OpRewritePattern<OpTy>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(OpTy storeOp,
- PatternRewriter &rewriter) const override;
-};
-
-/// Merges expand_shape operation with store/transferWriteOp operation.
-template <typename OpTy>
-class StoreOpOfExpandShapeOpFolder final : public OpRewritePattern<OpTy> {
-public:
- using OpRewritePattern<OpTy>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(OpTy storeOp,
- PatternRewriter &rewriter) const override;
-};
-
-/// Merges collapse_shape operation with store/transferWriteOp operation.
-template <typename OpTy>
-class StoreOpOfCollapseShapeOpFolder final : public OpRewritePattern<OpTy> {
-public:
- using OpRewritePattern<OpTy>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(OpTy storeOp,
- PatternRewriter &rewriter) const override;
-};
-
/// Folds subview(subview(x)) to a single subview(x).
class SubViewOfSubViewFolder : public OpRewritePattern<memref::SubViewOp> {
public:
@@ -286,226 +215,6 @@ struct TransferOpOfCollapseShapeOpFolder final
};
} // namespace
-static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter,
- Operation *op,
- memref::SubViewOp subviewOp) {
- return success();
-}
-
-template <typename OpTy>
-LogicalResult LoadOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
- OpTy loadOp, PatternRewriter &rewriter) const {
- auto subViewOp =
- getMemRefOperand(loadOp).template getDefiningOp<memref::SubViewOp>();
-
- if (!subViewOp)
- return rewriter.notifyMatchFailure(loadOp, "not a subview producer");
-
- LogicalResult preconditionResult =
- preconditionsFoldSubViewOp(rewriter, loadOp, subViewOp);
- if (failed(preconditionResult))
- return preconditionResult;
-
- SmallVector<Value> sourceIndices;
- affine::resolveIndicesIntoOpWithOffsetsAndStrides(
- rewriter, loadOp.getLoc(), subViewOp.getMixedOffsets(),
- subViewOp.getMixedStrides(), subViewOp.getDroppedDims(),
- loadOp.getIndices(), sourceIndices);
-
- llvm::TypeSwitch<Operation *, void>(loadOp)
- .Case([&](memref::LoadOp op) {
- rewriter.replaceOpWithNewOp<memref::LoadOp>(
- loadOp, subViewOp.getSource(), sourceIndices, op.getNontemporal());
- })
- .Case([&](vector::LoadOp op) {
- rewriter.replaceOpWithNewOp<vector::LoadOp>(
- op, op.getType(), subViewOp.getSource(), sourceIndices);
- })
- .Case([&](vector::MaskedLoadOp op) {
- rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
- op, op.getType(), subViewOp.getSource(), sourceIndices,
- op.getMask(), op.getPassThru());
- })
- .DefaultUnreachable("unexpected operation");
- return success();
-}
-
-template <typename OpTy>
-LogicalResult LoadOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
- OpTy loadOp, PatternRewriter &rewriter) const {
- auto expandShapeOp =
- getMemRefOperand(loadOp).template getDefiningOp<memref::ExpandShapeOp>();
-
- if (!expandShapeOp)
- return failure();
-
- SmallVector<Value> sourceIndices;
- // memref.load guarantees that indexes start inbounds while the vector
- // operations don't. This impacts if our linearization is `disjoint`
- resolveSourceIndicesExpandShape(loadOp.getLoc(), rewriter, expandShapeOp,
- loadOp.getIndices(), sourceIndices,
- isa<memref::LoadOp>(loadOp.getOperation()));
-
- return llvm::TypeSwitch<Operation *, LogicalResult>(loadOp)
- .Case([&](memref::LoadOp op) {
- rewriter.replaceOpWithNewOp<memref::LoadOp>(
- loadOp, expandShapeOp.getViewSource(), sourceIndices,
- op.getNontemporal());
- return success();
- })
- .Case([&](vector::LoadOp op) {
- rewriter.replaceOpWithNewOp<vector::LoadOp>(
- op, op.getType(), expandShapeOp.getViewSource(), sourceIndices,
- op.getNontemporal());
- return success();
- })
- .Case([&](vector::MaskedLoadOp op) {
- rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
- op, op.getType(), expandShapeOp.getViewSource(), sourceIndices,
- op.getMask(), op.getPassThru());
- return success();
- })
- .DefaultUnreachable("unexpected operation");
-}
-
-template <typename OpTy>
-LogicalResult LoadOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
- OpTy loadOp, PatternRewriter &rewriter) const {
- auto collapseShapeOp = getMemRefOperand(loadOp)
- .template getDefiningOp<memref::CollapseShapeOp>();
-
- if (!collapseShapeOp)
- return failure();
-
- SmallVector<Value> sourceIndices;
- resolveSourceIndicesCollapseShape(loadOp.getLoc(), rewriter, collapseShapeOp,
- loadOp.getIndices(), sourceIndices);
- llvm::TypeSwitch<Operation *, void>(loadOp)
- .Case([&](memref::LoadOp op) {
- rewriter.replaceOpWithNewOp<memref::LoadOp>(
- loadOp, collapseShapeOp.getViewSource(), sourceIndices,
- op.getNontemporal());
- })
- .Case([&](vector::LoadOp op) {
- rewriter.replaceOpWithNewOp<vector::LoadOp>(
- op, op.getType(), collapseShapeOp.getViewSource(), sourceIndices,
- op.getNontemporal());
- })
- .Case([&](vector::MaskedLoadOp op) {
- rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
- op, op.getType(), collapseShapeOp.getViewSource(), sourceIndices,
- op.getMask(), op.getPassThru());
- })
- .DefaultUnreachable("unexpected operation");
- return success();
-}
-
-template <typename OpTy>
-LogicalResult StoreOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
- OpTy storeOp, PatternRewriter &rewriter) const {
- auto subViewOp =
- getMemRefOperand(storeOp).template getDefiningOp<memref::SubViewOp>();
-
- if (!subViewOp)
- return rewriter.notifyMatchFailure(storeOp, "not a subview producer");
-
- LogicalResult preconditionResult =
- preconditionsFoldSubViewOp(rewriter, storeOp, subViewOp);
- if (failed(preconditionResult))
- return preconditionResult;
-
- SmallVector<Value> sourceIndices;
- affine::resolveIndicesIntoOpWithOffsetsAndStrides(
- rewriter, storeOp.getLoc(), subViewOp.getMixedOffsets(),
- subViewOp.getMixedStrides(), subViewOp.getDroppedDims(),
- storeOp.getIndices(), sourceIndices);
-
- llvm::TypeSwitch<Operation *, void>(storeOp)
- .Case([&](memref::StoreOp op) {
- rewriter.replaceOpWithNewOp<memref::StoreOp>(
- op, op.getValue(), subViewOp.getSource(), sourceIndices,
- op.getNontemporal());
- })
- .Case([&](vector::StoreOp op) {
- rewriter.replaceOpWithNewOp<vector::StoreOp>(
- op, op.getValueToStore(), subViewOp.getSource(), sourceIndices);
- })
- .Case([&](vector::MaskedStoreOp op) {
- rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
- op, subViewOp.getSource(), sourceIndices, op.getMask(),
- op.getValueToStore());
- })
- .DefaultUnreachable("unexpected operation");
- return success();
-}
-
-template <typename OpTy>
-LogicalResult StoreOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
- OpTy storeOp, PatternRewriter &rewriter) const {
- auto expandShapeOp =
- getMemRefOperand(storeOp).template getDefiningOp<memref::ExpandShapeOp>();
-
- if (!expandShapeOp)
- return failure();
-
- SmallVector<Value> sourceIndices;
- // memref.store guarantees that indexes start inbounds while the vector
- // operations don't. This impacts if our linearization is `disjoint`
- resolveSourceIndicesExpandShape(storeOp.getLoc(), rewriter, expandShapeOp,
- storeOp.getIndices(), sourceIndices,
- isa<memref::StoreOp>(storeOp.getOperation()));
- llvm::TypeSwitch<Operation *, void>(storeOp)
- .Case([&](memref::StoreOp op) {
- rewriter.replaceOpWithNewOp<memref::StoreOp>(
- storeOp, op.getValueToStore(), expandShapeOp.getViewSource(),
- sourceIndices, op.getNontemporal());
- })
- .Case([&](vector::StoreOp op) {
- rewriter.replaceOpWithNewOp<vector::StoreOp>(
- op, op.getValueToStore(), expandShapeOp.getViewSource(),
- sourceIndices, op.getNontemporal());
- })
- .Case([&](vector::MaskedStoreOp op) {
- rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
- op, expandShapeOp.getViewSource(), sourceIndices, op.getMask(),
- op.getValueToStore());
- })
- .DefaultUnreachable("unexpected operation");
- return success();
-}
-
-template <typename OpTy>
-LogicalResult StoreOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
- OpTy storeOp, PatternRewriter &rewriter) const {
- auto collapseShapeOp = getMemRefOperand(storeOp)
- .template getDefiningOp<memref::CollapseShapeOp>();
-
- if (!collapseShapeOp)
- return failure();
-
- SmallVector<Value> sourceIndices;
- resolveSourceIndicesCollapseShape(storeOp.getLoc(), rewriter, collapseShapeOp,
- storeOp.getIndices(), sourceIndices);
- llvm::TypeSwitch<Operation *, void>(storeOp)
- .Case([&](memref::StoreOp op) {
- rewriter.replaceOpWithNewOp<memref::StoreOp>(
- storeOp, op.getValueToStore(), collapseShapeOp.getViewSource(),
- sourceIndices, op.getNontemporal());
- })
- .Case([&](vector::StoreOp op) {
- rewriter.replaceOpWithNewOp<vector::StoreOp>(
- op, op.getValueToStore(), collapseShapeOp.getViewSource(),
- sourceIndices, op.getNontemporal());
- })
- .Case([&](vector::MaskedStoreOp op) {
- rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
- op, collapseShapeOp.getViewSource(), sourceIndices, op.getMask(),
- op.getValueToStore());
- })
- .DefaultUnreachable("unexpected operation");
- return success();
-}
-
LogicalResult
AccessOpOfSubViewOpFolder::matchAndRewrite(memref::IndexedAccessOpInterface op,
PatternRewriter &rewriter) const {
@@ -849,27 +558,13 @@ LogicalResult TransferOpOfCollapseShapeOpFolder::matchAndRewrite(
}
void memref::populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns) {
- patterns.add<
- // Interface-based patterns to which we will be migrating.
- AccessOpOfSubViewOpFolder, AccessOpOfExpandShapeOpFolder,
- AccessOpOfCollapseShapeOpFolder, IndexedMemCopyOpOfSubViewOpFolder,
- IndexedMemCopyOpOfExpandShapeOpFolder,
- IndexedMemCopyOpOfCollapseShapeOpFolder, TransferOpOfSubViewOpFolder,
- TransferOpOfExpandShapeOpFolder, TransferOpOfCollapseShapeOpFolder,
- // The old way of doing things. Don't add more of these.
- LoadOpOfSubViewOpFolder<vector::LoadOp>,
- LoadOpOfSubViewOpFolder<vector::MaskedLoadOp>,
- StoreOpOfSubViewOpFolder<vector::StoreOp>,
- StoreOpOfSubViewOpFolder<vector::MaskedStoreOp>,
- LoadOpOfExpandShapeOpFolder<vector::LoadOp>,
- LoadOpOfExpandShapeOpFolder<vector::MaskedLoadOp>,
- StoreOpOfExpandShapeOpFolder<vector::StoreOp>,
- StoreOpOfExpandShapeOpFolder<vector::MaskedStoreOp>,
- LoadOpOfCollapseShapeOpFolder<vector::LoadOp>,
- LoadOpOfCollapseShapeOpFolder<vector::MaskedLoadOp>,
- StoreOpOfCollapseShapeOpFolder<vector::StoreOp>,
- StoreOpOfCollapseShapeOpFolder<vector::MaskedStoreOp>,
- SubViewOfSubViewFolder>(patterns.getContext());
+ patterns
+ .add<AccessOpOfSubViewOpFolder, AccessOpOfExpandShapeOpFolder,
+ AccessOpOfCollapseShapeOpFolder, IndexedMemCopyOpOfSubViewOpFolder,
+ IndexedMemCopyOpOfExpandShapeOpFolder,
+ IndexedMemCopyOpOfCollapseShapeOpFolder, TransferOpOfSubViewOpFolder,
+ TransferOpOfExpandShapeOpFolder, TransferOpOfCollapseShapeOpFolder,
+ SubViewOfSubViewFolder>(patterns.getContext());
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 780b0cbb36120..51be1e4431e70 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -19,6 +19,7 @@
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/IR/MemoryAccessOpInterfaces.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/UB/IR/UBMatchers.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
@@ -483,6 +484,9 @@ void VectorDialect::initialize() {
addInterfaces<VectorInlinerInterface>();
+ declarePromisedInterfaces<memref::IndexedAccessOpInterface, LoadOp, StoreOp,
+ MaskedLoadOp, MaskedStoreOp, ExpandLoadOp,
+ CompressStoreOp>();
declarePromisedInterfaces<bufferization::BufferizableOpInterface,
TransferReadOp, TransferWriteOp, GatherOp, MaskOp,
YieldOp>();
diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index 4e0f07af95984..112a1db6fe93b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
add_mlir_dialect_library(MLIRVectorTransforms
BufferizableOpInterfaceImpl.cpp
+ IndexedAccessOpInterfaceImpl.cpp
LowerVectorBitCast.cpp
LowerVectorBroadcast.cpp
LowerVectorContract.cpp
diff --git a/mlir/lib/Dialect/Vector/Transforms/IndexedAccessOpInterfaceImpl.cpp b/mlir/lib/Dialect/Vector/Transforms/IndexedAccessOpInterfaceImpl.cpp
new file mode 100644
index 0000000000000..c91ea97a2f965
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Transforms/IndexedAccessOpInterfaceImpl.cpp
@@ -0,0 +1,101 @@
+//===- IndexedAccessOpInterfaceImpl.cpp -----------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+// Implement IndexedAccessOpInterface on vector dialect operations with
+// %memref[%i, %j, ...] operands so generic memref-dialect passes can rewrite
+// their base/index pairs. Transfer ops keep their VectorTransferOpInterface
+// patterns; gather/scatter have tensor-or-memref bases and index-vector
+// operands that do not fit IndexedAccessOpInterface's rank-matched index
+// contract.
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Vector/Transforms/IndexedAccessOpInterfaceImpl.h"
+
+#include "mlir/Dialect/MemRef/IR/MemoryAccessOpInterfaces.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/PatternMatch.h"
+
+using namespace mlir;
+using namespace mlir::memref;
+
+namespace {
+/// Return true if this op has the memref semantics expected by this model.
+template <typename LoadStoreOp>
+bool hasMemrefSemantics(Operation *op) {
+ return llvm::isa<MemRefType>(cast<LoadStoreOp>(op).getBase().getType());
+}
+
+/// Return the vector shape whose access strides must be preserved, marking
+/// scalable dimensions as dynamic.
+SmallVector<int64_t> getAccessedVectorShape(VectorType vecTy) {
+ return llvm::map_to_vector(
+ llvm::zip_equal(vecTy.getShape(), vecTy.getScalableDims()), [](auto dim) {
+ auto [size, scalable] = dim;
+ return scalable ? ShapedType::kDynamic : size;
+ });
+}
+
+template <typename LoadStoreOp>
+struct VectorLoadStoreLikeOpImpl final
+ : IndexedAccessOpInterface::ExternalModel<
+ VectorLoadStoreLikeOpImpl<LoadStoreOp>, LoadStoreOp> {
+ TypedValue<MemRefType> getAccessedMemref(Operation *op) const {
+ return cast<LoadStoreOp>(op).getBase();
+ }
+
+ Operation::operand_range getIndices(Operation *op) const {
+ return cast<LoadStoreOp>(op).getIndices();
+ }
+
+ SmallVector<int64_t> getAccessedShape(Operation *op) const {
+ assert(hasMemrefSemantics<LoadStoreOp>(op) &&
+ "expected vector op with memref semantics");
+ return getAccessedVectorShape(cast<LoadStoreOp>(op).getVectorType());
+ }
+
+ std::optional<SmallVector<Value>>
+ updateMemrefAndIndices(Operation *op, RewriterBase &rewriter, Value newMemref,
+ ValueRange newIndices) const {
+ assert(hasMemrefSemantics<LoadStoreOp>(op) &&
+ "expected vector op with memref semantics");
+ assert(llvm::isa<MemRefType>(newMemref.getType()) &&
+ "expected replacement memref");
+ rewriter.modifyOpInPlace(op, [&]() {
+ auto concreteOp = cast<LoadStoreOp>(op);
+ concreteOp.getBaseMutable().assign(newMemref);
+ concreteOp.getIndicesMutable().assign(newIndices);
+ });
+ return std::nullopt;
+ }
+
+ // TODO: The various load and store operations, at the very least vector.load
+ // and vector.store, should be taught a starts-in-bounds attribute that would
+ // let us optimize index generation.
+ bool hasInboundsIndices(Operation *op) const {
+ assert(hasMemrefSemantics<LoadStoreOp>(op) &&
+ "expected vector op with memref semantics");
+ return false;
+ }
+};
+
+template <typename... Ops>
+static void attachAll(MLIRContext *ctx) {
+ (Ops::template attachInterface<VectorLoadStoreLikeOpImpl<Ops>>(*ctx), ...);
+}
+
+} // namespace
+
+void mlir::vector::registerIndexedAccessOpInterfaceExternalModels(
+ DialectRegistry &registry) {
+ registry.addExtension(+[](MLIRContext *ctx, vector::VectorDialect *dialect) {
+ attachAll<vector::LoadOp, vector::StoreOp, vector::MaskedLoadOp,
+ vector::MaskedStoreOp, vector::ExpandLoadOp,
+ vector::CompressStoreOp>(ctx);
+ });
+}
diff --git a/mlir/lib/RegisterAllDialects.cpp b/mlir/lib/RegisterAllDialects.cpp
index 589730b785133..01a7401db4710 100644
--- a/mlir/lib/RegisterAllDialects.cpp
+++ b/mlir/lib/RegisterAllDialects.cpp
@@ -95,6 +95,7 @@
#include "mlir/Dialect/Vector/IR/ValueBoundsOpInterfaceImpl.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/Vector/Transforms/IndexedAccessOpInterfaceImpl.h"
#include "mlir/Dialect/Vector/Transforms/SubsetOpInterfaceImpl.h"
#include "mlir/Dialect/WasmSSA/IR/WasmSSA.h"
#include "mlir/Dialect/X86/X86Dialect.h"
@@ -197,6 +198,7 @@ void mlir::registerAllDialects(DialectRegistry &registry) {
tensor::registerValueBoundsOpInterfaceExternalModels(registry);
tosa::registerShardingInterfaceExternalModels(registry);
vector::registerBufferizableOpInterfaceExternalModels(registry);
+ vector::registerIndexedAccessOpInterfaceExternalModels(registry);
vector::registerSubsetOpInterfaceExternalModels(registry);
vector::registerValueBoundsOpInterfaceExternalModels(registry);
NVVM::registerNVVMTargetInterfaceExternalModels(registry);
diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
index 50c7ebaff1e6a..6e2702d936ee0 100644
--- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
+++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
@@ -642,6 +642,26 @@ func.func @fold_vector_load_subview(%src : memref<24x64xf32>,
// -----
+// TODO: This should fold, but implementing IndexedAccessOpInterface on vector.load
+// in a way that allows the fold would add complexity (emitting
+// `vector.shape_cast`s) that people wanted to keep out of the initial
+// implementation during previous discussions. (Note: this didn't work in the
+// pre-interface version of the pass either.)
+func.func @no_fold_scalar_equivalent_vector_load_subview(
+ %arg0 : memref<16xf32>, %off : index, %idx : index) -> vector<1xf32> {
+ %0 = memref.subview %arg0[%off][4][2] : memref<16xf32> to memref<4xf32, strided<[2], offset: ?>>
+ %1 = vector.load %0[%idx] : memref<4xf32, strided<[2], offset: ?>>, vector<1xf32>
+ return %1 : vector<1xf32>
+}
+
+// CHECK-LABEL: func @no_fold_scalar_equivalent_vector_load_subview
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<16xf32>
+// CHECK: %[[SUBVIEW:.*]] = memref.subview %[[ARG0]]
+// CHECK-NOT: vector.shape_cast
+// CHECK: vector.load %[[SUBVIEW]]
+
+// -----
+
func.func @fold_vector_maskedload_subview(
%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3: vector<32xi1>, %arg4: vector<32xf32>) -> vector<32xf32> {
%0 = memref.subview %arg0[%arg1, %arg2][1, 1][1, 1] : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
@@ -659,6 +679,21 @@ func.func @fold_vector_maskedload_subview(
// -----
+func.func @no_fold_vector_maskedload_subview_high_rank_vector(
+ %arg0 : memref<8xf32>, %idx : index,
+ %mask : vector<2x2x2xi1>, %pass : vector<2x2x2xf32>) -> vector<2x2x2xf32> {
+ %0 = memref.subview %arg0[%idx][1][1] : memref<8xf32> to memref<1xf32, strided<[1], offset: ?>>
+ %1 = vector.maskedload %0[%idx], %mask, %pass : memref<1xf32, strided<[1], offset: ?>>, vector<2x2x2xi1>, vector<2x2x2xf32> into vector<2x2x2xf32>
+ return %1 : vector<2x2x2xf32>
+}
+
+// CHECK-LABEL: func @no_fold_vector_maskedload_subview_high_rank_vector
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<8xf32>
+// CHECK: %[[SUBVIEW:.*]] = memref.subview %[[ARG0]]
+// CHECK: vector.maskedload %[[SUBVIEW]]
+
+// -----
+
func.func @fold_vector_store_subview(%src : memref<24x64xf32>,
%off1 : index,
%off2 : index,
@@ -723,6 +758,24 @@ func.func @fold_vector_load_expand_shape(
// -----
+// Folding this would require changing the vector op rank. That is handled by
+// vector drop-leading-unit-dim patterns, not by fold-memref-alias-ops.
+func.func @no_fold_vector_load_expand_shape_leading_unit(
+ %arg0 : memref<32xf32>, %arg1 : index) -> vector<1x8xf32> {
+ %c0 = arith.constant 0 : index
+ %0 = memref.expand_shape %arg0 [[0, 1]] output_shape [4, 8] : memref<32xf32> into memref<4x8xf32>
+ %1 = vector.load %0[%arg1, %c0] : memref<4x8xf32>, vector<1x8xf32>
+ return %1 : vector<1x8xf32>
+}
+
+// CHECK-LABEL: func @no_fold_vector_load_expand_shape_leading_unit
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<32xf32>
+// CHECK: memref.expand_shape %[[ARG0]]
+// CHECK-NOT: vector.shape_cast
+// CHECK: vector.load
+
+// -----
+
func.func @fold_vector_maskedload_expand_shape(
%arg0 : memref<32xf32>, %arg1 : index, %arg3: vector<8xi1>, %arg4: vector<8xf32>) -> vector<8xf32> {
%c0 = arith.constant 0 : index
@@ -742,6 +795,22 @@ func.func @fold_vector_maskedload_expand_shape(
// -----
+func.func @no_fold_vector_maskedload_expand_shape_high_rank_vector(
+ %arg0 : memref<32xf32>, %arg1 : index,
+ %mask : vector<2x2x2xi1>, %pass : vector<2x2x2xf32>) -> vector<2x2x2xf32> {
+ %c0 = arith.constant 0 : index
+ %0 = memref.expand_shape %arg0 [[0, 1]] output_shape [4, 8] : memref<32xf32> into memref<4x8xf32>
+ %1 = vector.maskedload %0[%arg1, %c0], %mask, %pass : memref<4x8xf32>, vector<2x2x2xi1>, vector<2x2x2xf32> into vector<2x2x2xf32>
+ return %1 : vector<2x2x2xf32>
+}
+
+// CHECK-LABEL: func @no_fold_vector_maskedload_expand_shape_high_rank_vector
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<32xf32>
+// CHECK: %[[EXPAND:.*]] = memref.expand_shape %[[ARG0]]
+// CHECK: vector.maskedload %[[EXPAND]]
+
+// -----
+
func.func @fold_vector_store_expand_shape(
%arg0 : memref<32xf32>, %arg1 : index, %val : vector<8xf32>) {
%c0 = arith.constant 0 : index
@@ -759,6 +828,24 @@ func.func @fold_vector_store_expand_shape(
// -----
+// Folding this would require changing the vector op rank. That is handled by
+// vector drop-leading-unit-dim patterns, not by fold-memref-alias-ops.
+func.func @no_fold_vector_store_expand_shape_leading_unit(
+ %arg0 : memref<32xf32>, %arg1 : index, %val : vector<1x8xf32>) {
+ %c0 = arith.constant 0 : index
+ %0 = memref.expand_shape %arg0 [[0, 1]] output_shape [4, 8] : memref<32xf32> into memref<4x8xf32>
+ vector.store %val, %0[%arg1, %c0] : memref<4x8xf32>, vector<1x8xf32>
+ return
+}
+
+// CHECK-LABEL: func @no_fold_vector_store_expand_shape_leading_unit
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<32xf32>
+// CHECK: memref.expand_shape %[[ARG0]]
+// CHECK-NOT: vector.shape_cast
+// CHECK: vector.store
+
+// -----
+
func.func @fold_vector_maskedstore_expand_shape(
%arg0 : memref<32xf32>, %arg1 : index, %arg3: vector<8xi1>, %arg4: vector<8xf32>) {
%c0 = arith.constant 0 : index
@@ -778,6 +865,44 @@ func.func @fold_vector_maskedstore_expand_shape(
// -----
+func.func @fold_vector_expandload_expand_shape(
+ %arg0 : memref<32xf32>, %arg1 : index, %arg3: vector<8xi1>, %arg4: vector<8xf32>) -> vector<8xf32> {
+ %c0 = arith.constant 0 : index
+ %0 = memref.expand_shape %arg0 [[0, 1]] output_shape [4, 8] : memref<32xf32> into memref<4x8xf32>
+ %1 = vector.expandload %0[%arg1, %c0], %arg3, %arg4 : memref<4x8xf32>, vector<8xi1>, vector<8xf32> into vector<8xf32>
+ return %1 : vector<8xf32>
+}
+
+// CHECK-LABEL: func @fold_vector_expandload_expand_shape
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<32xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: vector<8xi1>
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: vector<8xf32>
+// CHECK: %[[C0:.*]] = arith.constant 0
+// CHECK: %[[IDX:.*]] = affine.linearize_index [%[[ARG1]], %[[C0]]] by (4, 8)
+// CHECK: vector.expandload %[[ARG0]][%[[IDX]]], %[[ARG3]], %[[ARG4]]
+
+// -----
+
+func.func @fold_vector_compressstore_expand_shape(
+ %arg0 : memref<32xf32>, %arg1 : index, %arg3: vector<8xi1>, %arg4: vector<8xf32>) {
+ %c0 = arith.constant 0 : index
+ %0 = memref.expand_shape %arg0 [[0, 1]] output_shape [4, 8] : memref<32xf32> into memref<4x8xf32>
+ vector.compressstore %0[%arg1, %c0], %arg3, %arg4 : memref<4x8xf32>, vector<8xi1>, vector<8xf32>
+ return
+}
+
+// CHECK-LABEL: func @fold_vector_compressstore_expand_shape
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<32xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: vector<8xi1>
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: vector<8xf32>
+// CHECK: %[[C0:.*]] = arith.constant 0
+// CHECK: %[[IDX:.*]] = affine.linearize_index [%[[ARG1]], %[[C0]]] by (4, 8)
+// CHECK: vector.compressstore %[[ARG0]][%[[IDX]]], %[[ARG3]], %[[ARG4]]
+
+// -----
+
func.func @fold_vector_transfer_read_expand_shape(
%arg0 : memref<32xf32>, %arg1 : index) -> vector<8xf32> {
%c0 = arith.constant 0 : index
More information about the Mlir-commits
mailing list