[Mlir-commits] [mlir] 1b002d2 - Fold memref.expand_shape and memref.collapse_shape ops
Uday Bondhugula
llvmlistbot at llvm.org
Sat Aug 27 18:32:52 PDT 2022
Author: Arnab Dutta
Date: 2022-08-28T06:56:06+05:30
New Revision: 1b002d27683522e15d6ef3bccb74e751bbf56e74
URL: https://github.com/llvm/llvm-project/commit/1b002d27683522e15d6ef3bccb74e751bbf56e74
DIFF: https://github.com/llvm/llvm-project/commit/1b002d27683522e15d6ef3bccb74e751bbf56e74.diff
LOG: Fold memref.expand_shape and memref.collapse_shape ops
Fold memref.expand_shape and memref.collapse_shape ops into the
memref/affine load/store ops that consume their results.
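As a minimal sketch of the intended folds (adapted from the doc comments and
tests added in this commit; %arg0, %i, %j and %k are placeholder names, and the
pass actually materializes the composed indices through affine.apply ops rather
than inlining them as written here), a load through an expand_shape

  %0 = memref.expand_shape %arg0 [[0, 1], [2]]
      : memref<12x32xf32> into memref<2x6x32xf32>
  %1 = affine.load %0[%i, %j, %k] : memref<2x6x32xf32>

is rewritten to access the original memref directly,

  %1 = affine.load %arg0[6 * %i + %j, %k] : memref<12x32xf32>

while a load through a collapse_shape delinearizes the collapsed index,

  %0 = memref.collapse_shape %arg0 [[0, 1], [2]]
      : memref<2x6x32xf32> into memref<12x32xf32>
  %1 = affine.load %0[%i, %j] : memref<12x32xf32>

becoming

  %1 = affine.load %arg0[%i floordiv 6, %i mod 6, %j] : memref<2x6x32xf32>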
Reviewed By: bondhugula, nicolasvasilache
Differential Revision: https://reviews.llvm.org/D128986
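The patterns are exposed both programmatically, via
memref::populateFoldMemRefAliasOpPatterns, and as a standalone pass. A quick
way to try the pass, mirroring the RUN line of the new test file (input.mlir is
just a placeholder path):

  mlir-opt -fold-memref-alias-ops -split-input-file input.mlir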
Added:
mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
Modified:
mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
mlir/include/mlir/Dialect/Utils/IndexingUtils.h
mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
mlir/lib/Dialect/Utils/IndexingUtils.cpp
mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp
Removed:
mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp
mlir/test/Dialect/MemRef/fold-subview-ops.mlir
################################################################################
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
index 4c7781d06a2ac..82c2b5bceb9a2 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
@@ -39,9 +39,9 @@ class AllocOp;
/// Collects a set of patterns to rewrite ops within the memref dialect.
void populateExpandOpsPatterns(RewritePatternSet &patterns);
-/// Appends patterns for folding memref.subview ops into consumer load/store ops
-/// into `patterns`.
-void populateFoldSubViewOpPatterns(RewritePatternSet &patterns);
+/// Appends patterns for folding memref aliasing ops into consumer load/store
+/// ops into `patterns`.
+void populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns);
/// Appends patterns that resolve `memref.dim` operations with values that are
/// defined by operations that implement the
@@ -91,9 +91,9 @@ LogicalResult multiBuffer(memref::AllocOp allocOp, unsigned multiplier);
/// `memref_reinterpret_cast`.
std::unique_ptr<Pass> createExpandOpsPass();
-/// Creates an operation pass to fold memref.subview ops into consumer
+/// Creates an operation pass to fold memref aliasing ops into consumer
/// load/store ops.
-std::unique_ptr<Pass> createFoldSubViewOpsPass();
+std::unique_ptr<Pass> createFoldMemRefAliasOpsPass();
/// Creates an interprocedural pass to normalize memrefs to have a trivial
/// (identity) layout map.
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
index ee769c11590b9..5ac124a0fcaa8 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
@@ -16,13 +16,13 @@ def ExpandOps : Pass<"memref-expand"> {
let constructor = "mlir::memref::createExpandOpsPass()";
}
-def FoldSubViewOps : Pass<"fold-memref-subview-ops"> {
- let summary = "Fold memref.subview ops into consumer load/store ops";
+def FoldMemRefAliasOps : Pass<"fold-memref-alias-ops"> {
+ let summary = "Fold memref alias ops into consumer load/store ops";
let description = [{
- The pass folds loading/storing from/to subview ops to loading/storing
+ The pass folds loading/storing from/to memref aliasing ops to loading/storing
from/to the original memref.
}];
- let constructor = "mlir::memref::createFoldSubViewOpsPass()";
+ let constructor = "mlir::memref::createFoldMemRefAliasOpsPass()";
let dependentDialects = [
"AffineDialect", "memref::MemRefDialect", "vector::VectorDialect"
];
diff --git a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
index 3f2dd00c696f8..c1857462d4c67 100644
--- a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
@@ -14,6 +14,7 @@
#ifndef MLIR_DIALECT_UTILS_INDEXINGUTILS_H
#define MLIR_DIALECT_UTILS_INDEXINGUTILS_H
+#include "mlir/IR/Builders.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
@@ -47,6 +48,15 @@ void applyPermutationToVector(SmallVector<T, N> &inVec,
SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
unsigned dropFront = 0,
unsigned dropBack = 0);
+
+/// Computes and returns the linearized affine expression w.r.t. `basis`.
+mlir::AffineExpr getLinearAffineExpr(ArrayRef<int64_t> basis, mlir::Builder &b);
+
+/// Given the strides in the dimension space, returns the affine expressions
+/// for the per-dimension offsets of a delinearized index.
+SmallVector<mlir::AffineExpr, 4>
+getDelinearizedAffineExpr(ArrayRef<int64_t> strides, mlir::Builder &b);
+
} // namespace mlir
#endif // MLIR_DIALECT_UTILS_INDEXINGUTILS_H
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index aa6624f07a2dc..b7783c601a92b 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -1486,6 +1486,13 @@ def Vector_TransferWriteOp :
"ValueRange":$indices,
CArg<"Optional<ArrayRef<bool>>", "::llvm::None">:$inBounds)>,
];
+
+ let extraClassDeclaration = [{
+ /// This method is added to maintain uniformity with load/store
+ /// ops of other dialects.
+ Value getValue() { return getVector(); }
+ }];
+
let hasFolder = 1;
let hasCanonicalizer = 1;
let hasCustomAssemblyFormat = 1;
diff --git a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
index 1f8167cc318e5..f85b6e50e91e4 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
@@ -1,7 +1,7 @@
add_mlir_dialect_library(MLIRMemRefTransforms
ComposeSubView.cpp
ExpandOps.cpp
- FoldSubViewOps.cpp
+ FoldMemRefAliasOps.cpp
MultiBuffer.cpp
NormalizeMemRefs.cpp
ResolveShapedTypeResultDims.cpp
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
new file mode 100644
index 0000000000000..3ce18ec442c04
--- /dev/null
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
@@ -0,0 +1,562 @@
+//===- FoldMemRefAliasOps.cpp - Fold memref alias ops -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This transformation pass folds loading/storing from/to memref aliasing ops
+// (subview, expand_shape, collapse_shape) into loading/storing from/to the
+// original memref.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Transforms/Passes.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/SmallBitVector.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Utility functions
+//===----------------------------------------------------------------------===//
+
+/// Given the 'indices' of a load/store operation where the memref is a result
+/// of an expand_shape op, returns the indices w.r.t. the source memref of the
+/// expand_shape op. For example,
+///
+/// %0 = ... : memref<12x42xf32>
+/// %1 = memref.expand_shape %0 [[0, 1], [2]]
+/// : memref<12x42xf32> into memref<2x6x42xf32>
+///   %2 = load %1[%i1, %i2, %i3] : memref<2x6x42xf32>
+///
+/// could be folded into
+///
+///   %2 = load %0[6 * %i1 + %i2, %i3] :
+/// memref<12x42xf32>
+static LogicalResult
+resolveSourceIndicesExpandShape(Location loc, PatternRewriter &rewriter,
+ memref::ExpandShapeOp expandShapeOp,
+ ValueRange indices,
+ SmallVectorImpl<Value> &sourceIndices) {
+ for (SmallVector<int64_t, 2> groups :
+ expandShapeOp.getReassociationIndices()) {
+ assert(!groups.empty() && "association indices groups cannot be empty");
+ unsigned groupSize = groups.size();
+ SmallVector<int64_t> suffixProduct(groupSize);
+ // Calculate suffix product of dimension sizes for all dimensions of expand
+ // shape op result.
+ suffixProduct[groupSize - 1] = 1;
+ for (unsigned i = groupSize - 1; i > 0; i--)
+ suffixProduct[i - 1] =
+ suffixProduct[i] *
+ expandShapeOp.getType().cast<MemRefType>().getDimSize(groups[i]);
+ SmallVector<Value> dynamicIndices(groupSize);
+ for (unsigned i = 0; i < groupSize; i++)
+ dynamicIndices[i] = indices[groups[i]];
+    // Construct the expression for the index value w.r.t. the expand_shape op
+    // source, corresponding to the indices w.r.t. the expand_shape op result.
+ AffineExpr srcIndexExpr = getLinearAffineExpr(suffixProduct, rewriter);
+ sourceIndices.push_back(rewriter.create<AffineApplyOp>(
+ loc,
+ AffineMap::get(/*numDims=*/groupSize, /*numSymbols=*/0, srcIndexExpr),
+ dynamicIndices));
+ }
+ return success();
+}
+
+/// Given the 'indices' of a load/store operation where the memref is a result
+/// of a collapse_shape op, returns the indices w.r.t. the source memref of
+/// the collapse_shape op. For example,
+///
+/// %0 = ... : memref<2x6x42xf32>
+/// %1 = memref.collapse_shape %0 [[0, 1], [2]]
+/// : memref<2x6x42xf32> into memref<12x42xf32>
+/// %2 = load %1[%i1, %i2] : memref<12x42xf32>
+///
+/// could be folded into
+///
+/// %2 = load %0[%i1 / 6, %i1 % 6, %i2] :
+/// memref<2x6x42xf32>
+static LogicalResult
+resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter,
+ memref::CollapseShapeOp collapseShapeOp,
+ ValueRange indices,
+ SmallVectorImpl<Value> &sourceIndices) {
+ unsigned cnt = 0;
+ SmallVector<Value> tmp(indices.size());
+ SmallVector<Value> dynamicIndices;
+ for (SmallVector<int64_t, 2> groups :
+ collapseShapeOp.getReassociationIndices()) {
+ assert(!groups.empty() && "association indices groups cannot be empty");
+ dynamicIndices.push_back(indices[cnt++]);
+ unsigned groupSize = groups.size();
+ SmallVector<int64_t> suffixProduct(groupSize);
+ // Calculate suffix product for all collapse op source dimension sizes.
+ suffixProduct[groupSize - 1] = 1;
+ for (unsigned i = groupSize - 1; i > 0; i--)
+ suffixProduct[i - 1] =
+ suffixProduct[i] * collapseShapeOp.getSrcType().getDimSize(groups[i]);
+    // Derive the index values along all dimensions of the source corresponding
+    // to the index w.r.t. the collapse_shape op output.
+ SmallVector<AffineExpr, 4> srcIndexExpr =
+ getDelinearizedAffineExpr(suffixProduct, rewriter);
+ for (unsigned i = 0; i < groupSize; i++)
+ sourceIndices.push_back(rewriter.create<AffineApplyOp>(
+ loc, AffineMap::get(/*numDims=*/1, /*numSymbols=*/0, srcIndexExpr[i]),
+ dynamicIndices));
+ dynamicIndices.clear();
+ }
+ if (collapseShapeOp.getReassociationIndices().empty()) {
+ auto zeroAffineMap = rewriter.getConstantAffineMap(0);
+ unsigned srcRank =
+ collapseShapeOp.getViewSource().getType().cast<MemRefType>().getRank();
+ for (unsigned i = 0; i < srcRank; i++)
+ sourceIndices.push_back(
+ rewriter.create<AffineApplyOp>(loc, zeroAffineMap, dynamicIndices));
+ }
+ return success();
+}
+
+/// Given the 'indices' of a load/store operation where the memref is a result
+/// of a subview op, returns the indices w.r.t. the source memref of the
+/// subview op. For example,
+///
+/// %0 = ... : memref<12x42xf32>
+/// %1 = subview %0[%arg0, %arg1][][%stride1, %stride2] : memref<12x42xf32> to
+/// memref<4x4xf32, offset=?, strides=[?, ?]>
+/// %2 = load %1[%i1, %i2] : memref<4x4xf32, offset=?, strides=[?, ?]>
+///
+/// could be folded into
+///
+///   %2 = load %0[%arg0 + %i1 * %stride1, %arg1 + %i2 * %stride2] :
+/// memref<12x42xf32>
+static LogicalResult
+resolveSourceIndicesSubView(Location loc, PatternRewriter &rewriter,
+ memref::SubViewOp subViewOp, ValueRange indices,
+ SmallVectorImpl<Value> &sourceIndices) {
+ SmallVector<OpFoldResult> mixedOffsets = subViewOp.getMixedOffsets();
+ SmallVector<OpFoldResult> mixedSizes = subViewOp.getMixedSizes();
+ SmallVector<OpFoldResult> mixedStrides = subViewOp.getMixedStrides();
+
+ SmallVector<Value> useIndices;
+  // Check if this is a rank-reducing case. If so, add a zero index for every
+  // dropped unit dimension.
+ unsigned resultDim = 0;
+ llvm::SmallBitVector unusedDims = subViewOp.getDroppedDims();
+ for (auto dim : llvm::seq<unsigned>(0, subViewOp.getSourceType().getRank())) {
+ if (unusedDims.test(dim))
+ useIndices.push_back(rewriter.create<arith::ConstantIndexOp>(loc, 0));
+ else
+ useIndices.push_back(indices[resultDim++]);
+ }
+ if (useIndices.size() != mixedOffsets.size())
+ return failure();
+ sourceIndices.resize(useIndices.size());
+ for (auto index : llvm::seq<size_t>(0, mixedOffsets.size())) {
+ SmallVector<Value> dynamicOperands;
+ AffineExpr expr = rewriter.getAffineDimExpr(0);
+ unsigned numSymbols = 0;
+ dynamicOperands.push_back(useIndices[index]);
+
+    // Multiply by the stride.
+ if (auto attr = mixedStrides[index].dyn_cast<Attribute>()) {
+ expr = expr * attr.cast<IntegerAttr>().getInt();
+ } else {
+ dynamicOperands.push_back(mixedStrides[index].get<Value>());
+ expr = expr * rewriter.getAffineSymbolExpr(numSymbols++);
+ }
+
+ // Add the offset.
+ if (auto attr = mixedOffsets[index].dyn_cast<Attribute>()) {
+ expr = expr + attr.cast<IntegerAttr>().getInt();
+ } else {
+ dynamicOperands.push_back(mixedOffsets[index].get<Value>());
+ expr = expr + rewriter.getAffineSymbolExpr(numSymbols++);
+ }
+ Location loc = subViewOp.getLoc();
+ sourceIndices[index] = rewriter.create<AffineApplyOp>(
+ loc, AffineMap::get(1, numSymbols, expr), dynamicOperands);
+ }
+ return success();
+}
+
+/// Helpers to access the memref operand for each op.
+template <typename LoadOrStoreOpTy>
+static Value getMemRefOperand(LoadOrStoreOpTy op) {
+ return op.getMemref();
+}
+
+static Value getMemRefOperand(vector::TransferReadOp op) {
+ return op.getSource();
+}
+
+static Value getMemRefOperand(vector::TransferWriteOp op) {
+ return op.getSource();
+}
+
+/// Given the permutation map of the original
+/// `vector.transfer_read`/`vector.transfer_write` operations, compute the
+/// permutation map to use after the subview is folded with it.
+static AffineMapAttr getPermutationMapAttr(MLIRContext *context,
+ memref::SubViewOp subViewOp,
+ AffineMap currPermutationMap) {
+ llvm::SmallBitVector unusedDims = subViewOp.getDroppedDims();
+ SmallVector<AffineExpr> exprs;
+ int64_t sourceRank = subViewOp.getSourceType().getRank();
+ for (auto dim : llvm::seq<int64_t>(0, sourceRank)) {
+ if (unusedDims.test(dim))
+ continue;
+ exprs.push_back(getAffineDimExpr(dim, context));
+ }
+ auto resultDimToSourceDimMap = AffineMap::get(sourceRank, 0, exprs, context);
+ return AffineMapAttr::get(
+ currPermutationMap.compose(resultDimToSourceDimMap));
+}
+
+//===----------------------------------------------------------------------===//
+// Patterns
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Merges subview operation with load/transferRead operation.
+template <typename OpTy>
+class LoadOpOfSubViewOpFolder final : public OpRewritePattern<OpTy> {
+public:
+ using OpRewritePattern<OpTy>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(OpTy loadOp,
+ PatternRewriter &rewriter) const override;
+};
+
+/// Merges expand_shape operation with load/transferRead operation.
+template <typename OpTy>
+class LoadOpOfExpandShapeOpFolder final : public OpRewritePattern<OpTy> {
+public:
+ using OpRewritePattern<OpTy>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(OpTy loadOp,
+ PatternRewriter &rewriter) const override;
+};
+
+/// Merges collapse_shape operation with load/transferRead operation.
+template <typename OpTy>
+class LoadOpOfCollapseShapeOpFolder final : public OpRewritePattern<OpTy> {
+public:
+ using OpRewritePattern<OpTy>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(OpTy loadOp,
+ PatternRewriter &rewriter) const override;
+};
+
+/// Merges subview operation with store/transferWriteOp operation.
+template <typename OpTy>
+class StoreOpOfSubViewOpFolder final : public OpRewritePattern<OpTy> {
+public:
+ using OpRewritePattern<OpTy>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(OpTy storeOp,
+ PatternRewriter &rewriter) const override;
+};
+
+/// Merges expand_shape operation with store/transferWriteOp operation.
+template <typename OpTy>
+class StoreOpOfExpandShapeOpFolder final : public OpRewritePattern<OpTy> {
+public:
+ using OpRewritePattern<OpTy>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(OpTy storeOp,
+ PatternRewriter &rewriter) const override;
+};
+
+/// Merges collapse_shape operation with store/transferWriteOp operation.
+template <typename OpTy>
+class StoreOpOfCollapseShapeOpFolder final : public OpRewritePattern<OpTy> {
+public:
+ using OpRewritePattern<OpTy>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(OpTy storeOp,
+ PatternRewriter &rewriter) const override;
+};
+
+} // namespace
+
+static SmallVector<Value>
+calculateExpandedAccessIndices(AffineMap affineMap, SmallVector<Value> indices,
+ Location loc, PatternRewriter &rewriter) {
+ SmallVector<Value> expandedIndices;
+ for (unsigned i = 0, e = affineMap.getNumResults(); i < e; i++)
+ expandedIndices.push_back(
+ rewriter.create<AffineApplyOp>(loc, affineMap.getSubMap({i}), indices));
+ return expandedIndices;
+}
+
+template <typename OpTy>
+LogicalResult LoadOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
+ OpTy loadOp, PatternRewriter &rewriter) const {
+ auto subViewOp =
+ getMemRefOperand(loadOp).template getDefiningOp<memref::SubViewOp>();
+
+ if (!subViewOp)
+ return failure();
+
+ SmallVector<Value> indices(loadOp.getIndices().begin(),
+ loadOp.getIndices().end());
+  // For affine ops, we need to apply the affine map to the operands to get
+  // the "actual" indices.
+ if (auto affineLoadOp = dyn_cast<AffineLoadOp>(loadOp.getOperation())) {
+ AffineMap affineMap = affineLoadOp.getAffineMap();
+ auto expandedIndices = calculateExpandedAccessIndices(
+ affineMap, indices, loadOp.getLoc(), rewriter);
+ indices.assign(expandedIndices.begin(), expandedIndices.end());
+ }
+ SmallVector<Value, 4> sourceIndices;
+ if (failed(resolveSourceIndicesSubView(loadOp.getLoc(), rewriter, subViewOp,
+ indices, sourceIndices)))
+ return failure();
+ llvm::TypeSwitch<Operation *, void>(loadOp)
+ .Case<AffineLoadOp, memref::LoadOp>([&](auto op) {
+ rewriter.replaceOpWithNewOp<decltype(op)>(loadOp, subViewOp.source(),
+ sourceIndices);
+ })
+ .Case([&](vector::TransferReadOp transferReadOp) {
+ if (transferReadOp.getTransferRank() == 0) {
+ // TODO: Propagate the error.
+ return;
+ }
+ rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
+ transferReadOp, transferReadOp.getVectorType(), subViewOp.source(),
+ sourceIndices,
+ getPermutationMapAttr(rewriter.getContext(), subViewOp,
+ transferReadOp.getPermutationMap()),
+ transferReadOp.getPadding(),
+ /*mask=*/Value(), transferReadOp.getInBoundsAttr());
+ })
+ .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
+ return success();
+}
+
+template <typename OpTy>
+LogicalResult LoadOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
+ OpTy loadOp, PatternRewriter &rewriter) const {
+ auto expandShapeOp =
+ getMemRefOperand(loadOp).template getDefiningOp<memref::ExpandShapeOp>();
+
+ if (!expandShapeOp)
+ return failure();
+
+ SmallVector<Value> indices(loadOp.getIndices().begin(),
+ loadOp.getIndices().end());
+  // For affine ops, we need to apply the affine map to the operands to get
+  // the "actual" indices.
+ if (auto affineLoadOp = dyn_cast<AffineLoadOp>(loadOp.getOperation())) {
+ AffineMap affineMap = affineLoadOp.getAffineMap();
+ auto expandedIndices = calculateExpandedAccessIndices(
+ affineMap, indices, loadOp.getLoc(), rewriter);
+ indices.assign(expandedIndices.begin(), expandedIndices.end());
+ }
+ SmallVector<Value, 4> sourceIndices;
+ if (failed(resolveSourceIndicesExpandShape(
+ loadOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices)))
+ return failure();
+ llvm::TypeSwitch<Operation *, void>(loadOp)
+ .Case<AffineLoadOp, memref::LoadOp>([&](auto op) {
+ rewriter.replaceOpWithNewOp<decltype(op)>(
+ loadOp, expandShapeOp.getViewSource(), sourceIndices);
+ })
+ .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
+ return success();
+}
+
+template <typename OpTy>
+LogicalResult LoadOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
+ OpTy loadOp, PatternRewriter &rewriter) const {
+ auto collapseShapeOp = getMemRefOperand(loadOp)
+ .template getDefiningOp<memref::CollapseShapeOp>();
+
+ if (!collapseShapeOp)
+ return failure();
+
+ SmallVector<Value> indices(loadOp.getIndices().begin(),
+ loadOp.getIndices().end());
+  // For affine ops, we need to apply the affine map to the operands to get
+  // the "actual" indices.
+ if (auto affineLoadOp = dyn_cast<AffineLoadOp>(loadOp.getOperation())) {
+ AffineMap affineMap = affineLoadOp.getAffineMap();
+ auto expandedIndices = calculateExpandedAccessIndices(
+ affineMap, indices, loadOp.getLoc(), rewriter);
+ indices.assign(expandedIndices.begin(), expandedIndices.end());
+ }
+ SmallVector<Value, 4> sourceIndices;
+ if (failed(resolveSourceIndicesCollapseShape(
+ loadOp.getLoc(), rewriter, collapseShapeOp, indices, sourceIndices)))
+ return failure();
+ llvm::TypeSwitch<Operation *, void>(loadOp)
+ .Case<AffineLoadOp, memref::LoadOp>([&](auto op) {
+ rewriter.replaceOpWithNewOp<decltype(op)>(
+ loadOp, collapseShapeOp.getViewSource(), sourceIndices);
+ })
+ .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
+ return success();
+}
+
+template <typename OpTy>
+LogicalResult StoreOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
+ OpTy storeOp, PatternRewriter &rewriter) const {
+ auto subViewOp =
+ getMemRefOperand(storeOp).template getDefiningOp<memref::SubViewOp>();
+
+ if (!subViewOp)
+ return failure();
+
+ SmallVector<Value> indices(storeOp.getIndices().begin(),
+ storeOp.getIndices().end());
+  // For affine ops, we need to apply the affine map to the operands to get
+  // the "actual" indices.
+ if (auto affineStoreOp = dyn_cast<AffineStoreOp>(storeOp.getOperation())) {
+ AffineMap affineMap = affineStoreOp.getAffineMap();
+ auto expandedIndices = calculateExpandedAccessIndices(
+ affineMap, indices, storeOp.getLoc(), rewriter);
+ indices.assign(expandedIndices.begin(), expandedIndices.end());
+ }
+ SmallVector<Value, 4> sourceIndices;
+ if (failed(resolveSourceIndicesSubView(storeOp.getLoc(), rewriter, subViewOp,
+ indices, sourceIndices)))
+ return failure();
+ llvm::TypeSwitch<Operation *, void>(storeOp)
+ .Case<AffineStoreOp, memref::StoreOp>([&](auto op) {
+ rewriter.replaceOpWithNewOp<decltype(op)>(
+ storeOp, storeOp.getValue(), subViewOp.source(), sourceIndices);
+ })
+ .Case([&](vector::TransferWriteOp op) {
+ // TODO: support 0-d corner case.
+ if (op.getTransferRank() == 0)
+ return;
+ rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
+ op, op.getValue(), subViewOp.source(), sourceIndices,
+ getPermutationMapAttr(rewriter.getContext(), subViewOp,
+ op.getPermutationMap()),
+ op.getInBoundsAttr());
+ })
+ .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
+ return success();
+}
+
+template <typename OpTy>
+LogicalResult StoreOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
+ OpTy storeOp, PatternRewriter &rewriter) const {
+ auto expandShapeOp =
+ getMemRefOperand(storeOp).template getDefiningOp<memref::ExpandShapeOp>();
+
+ if (!expandShapeOp)
+ return failure();
+
+ SmallVector<Value> indices(storeOp.getIndices().begin(),
+ storeOp.getIndices().end());
+  // For affine ops, we need to apply the affine map to the operands to get
+  // the "actual" indices.
+ if (auto affineStoreOp = dyn_cast<AffineStoreOp>(storeOp.getOperation())) {
+ AffineMap affineMap = affineStoreOp.getAffineMap();
+ auto expandedIndices = calculateExpandedAccessIndices(
+ affineMap, indices, storeOp.getLoc(), rewriter);
+ indices.assign(expandedIndices.begin(), expandedIndices.end());
+ }
+ SmallVector<Value, 4> sourceIndices;
+ if (failed(resolveSourceIndicesExpandShape(
+ storeOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices)))
+ return failure();
+ llvm::TypeSwitch<Operation *, void>(storeOp)
+ .Case<AffineStoreOp, memref::StoreOp>([&](auto op) {
+ rewriter.replaceOpWithNewOp<decltype(op)>(storeOp, storeOp.getValue(),
+ expandShapeOp.getViewSource(),
+ sourceIndices);
+ })
+ .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
+ return success();
+}
+
+template <typename OpTy>
+LogicalResult StoreOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
+ OpTy storeOp, PatternRewriter &rewriter) const {
+ auto collapseShapeOp = getMemRefOperand(storeOp)
+ .template getDefiningOp<memref::CollapseShapeOp>();
+
+ if (!collapseShapeOp)
+ return failure();
+
+ SmallVector<Value> indices(storeOp.getIndices().begin(),
+ storeOp.getIndices().end());
+  // For affine ops, we need to apply the affine map to the operands to get
+  // the "actual" indices.
+ if (auto affineStoreOp = dyn_cast<AffineStoreOp>(storeOp.getOperation())) {
+ AffineMap affineMap = affineStoreOp.getAffineMap();
+ auto expandedIndices = calculateExpandedAccessIndices(
+ affineMap, indices, storeOp.getLoc(), rewriter);
+ indices.assign(expandedIndices.begin(), expandedIndices.end());
+ }
+ SmallVector<Value, 4> sourceIndices;
+ if (failed(resolveSourceIndicesCollapseShape(
+ storeOp.getLoc(), rewriter, collapseShapeOp, indices, sourceIndices)))
+ return failure();
+ llvm::TypeSwitch<Operation *, void>(storeOp)
+ .Case<AffineStoreOp, memref::StoreOp>([&](auto op) {
+ rewriter.replaceOpWithNewOp<decltype(op)>(
+ storeOp, storeOp.getValue(), collapseShapeOp.getViewSource(),
+ sourceIndices);
+ })
+ .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
+ return success();
+}
+
+void memref::populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns) {
+ patterns.add<LoadOpOfSubViewOpFolder<AffineLoadOp>,
+ LoadOpOfSubViewOpFolder<memref::LoadOp>,
+ LoadOpOfSubViewOpFolder<vector::TransferReadOp>,
+ StoreOpOfSubViewOpFolder<AffineStoreOp>,
+ StoreOpOfSubViewOpFolder<memref::StoreOp>,
+ StoreOpOfSubViewOpFolder<vector::TransferWriteOp>,
+ LoadOpOfExpandShapeOpFolder<AffineLoadOp>,
+ LoadOpOfExpandShapeOpFolder<memref::LoadOp>,
+ StoreOpOfExpandShapeOpFolder<AffineStoreOp>,
+ StoreOpOfExpandShapeOpFolder<memref::StoreOp>,
+ LoadOpOfCollapseShapeOpFolder<AffineLoadOp>,
+ LoadOpOfCollapseShapeOpFolder<memref::LoadOp>,
+ StoreOpOfCollapseShapeOpFolder<AffineStoreOp>,
+ StoreOpOfCollapseShapeOpFolder<memref::StoreOp>>(
+ patterns.getContext());
+}
+
+//===----------------------------------------------------------------------===//
+// Pass registration
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+#define GEN_PASS_CLASSES
+#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
+
+struct FoldMemRefAliasOpsPass final
+ : public FoldMemRefAliasOpsBase<FoldMemRefAliasOpsPass> {
+ void runOnOperation() override;
+};
+
+} // namespace
+
+void FoldMemRefAliasOpsPass::runOnOperation() {
+ RewritePatternSet patterns(&getContext());
+ memref::populateFoldMemRefAliasOpPatterns(patterns);
+ (void)applyPatternsAndFoldGreedily(getOperation()->getRegions(),
+ std::move(patterns));
+}
+
+std::unique_ptr<Pass> memref::createFoldMemRefAliasOpsPass() {
+ return std::make_unique<FoldMemRefAliasOpsPass>();
+}
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp
deleted file mode 100644
index 85d28ee5a020f..0000000000000
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldSubViewOps.cpp
+++ /dev/null
@@ -1,276 +0,0 @@
-//===- FoldSubViewOps.cpp - Fold memref.subview ops -----------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This transformation pass folds loading/storing from/to subview ops into
-// loading/storing from/to the original memref.
-//
-//===----------------------------------------------------------------------===//
-
-#include "PassDetail.h"
-#include "mlir/Dialect/Affine/IR/AffineOps.h"
-#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
-#include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/Dialect/MemRef/Transforms/Passes.h"
-#include "mlir/Dialect/Vector/IR/VectorOps.h"
-#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "llvm/ADT/SmallBitVector.h"
-
-using namespace mlir;
-
-//===----------------------------------------------------------------------===//
-// Utility functions
-//===----------------------------------------------------------------------===//
-
-/// Given the 'indices' of an load/store operation where the memref is a result
-/// of a subview op, returns the indices w.r.t to the source memref of the
-/// subview op. For example
-///
-/// %0 = ... : memref<12x42xf32>
-/// %1 = subview %0[%arg0, %arg1][][%stride1, %stride2] : memref<12x42xf32> to
-/// memref<4x4xf32, offset=?, strides=[?, ?]>
-/// %2 = load %1[%i1, %i2] : memref<4x4xf32, offset=?, strides=[?, ?]>
-///
-/// could be folded into
-///
-/// %2 = load %0[%arg0 + %i1 * %stride1][%arg1 + %i2 * %stride2] :
-/// memref<12x42xf32>
-static LogicalResult
-resolveSourceIndices(Location loc, PatternRewriter &rewriter,
- memref::SubViewOp subViewOp, ValueRange indices,
- SmallVectorImpl<Value> &sourceIndices) {
- SmallVector<OpFoldResult> mixedOffsets = subViewOp.getMixedOffsets();
- SmallVector<OpFoldResult> mixedSizes = subViewOp.getMixedSizes();
- SmallVector<OpFoldResult> mixedStrides = subViewOp.getMixedStrides();
-
- SmallVector<Value> useIndices;
- // Check if this is rank-reducing case. Then for every unit-dim size add a
- // zero to the indices.
- unsigned resultDim = 0;
- llvm::SmallBitVector unusedDims = subViewOp.getDroppedDims();
- for (auto dim : llvm::seq<unsigned>(0, subViewOp.getSourceType().getRank())) {
- if (unusedDims.test(dim))
- useIndices.push_back(rewriter.create<arith::ConstantIndexOp>(loc, 0));
- else
- useIndices.push_back(indices[resultDim++]);
- }
- if (useIndices.size() != mixedOffsets.size())
- return failure();
- sourceIndices.resize(useIndices.size());
- for (auto index : llvm::seq<size_t>(0, mixedOffsets.size())) {
- SmallVector<Value> dynamicOperands;
- AffineExpr expr = rewriter.getAffineDimExpr(0);
- unsigned numSymbols = 0;
- dynamicOperands.push_back(useIndices[index]);
-
- // Multiply the stride;
- if (auto attr = mixedStrides[index].dyn_cast<Attribute>()) {
- expr = expr * attr.cast<IntegerAttr>().getInt();
- } else {
- dynamicOperands.push_back(mixedStrides[index].get<Value>());
- expr = expr * rewriter.getAffineSymbolExpr(numSymbols++);
- }
-
- // Add the offset.
- if (auto attr = mixedOffsets[index].dyn_cast<Attribute>()) {
- expr = expr + attr.cast<IntegerAttr>().getInt();
- } else {
- dynamicOperands.push_back(mixedOffsets[index].get<Value>());
- expr = expr + rewriter.getAffineSymbolExpr(numSymbols++);
- }
- Location loc = subViewOp.getLoc();
- sourceIndices[index] = rewriter.create<AffineApplyOp>(
- loc, AffineMap::get(1, numSymbols, expr), dynamicOperands);
- }
- return success();
-}
-
-/// Helpers to access the memref operand for each op.
-template <typename LoadOrStoreOpTy>
-static Value getMemRefOperand(LoadOrStoreOpTy op) {
- return op.getMemref();
-}
-
-static Value getMemRefOperand(vector::TransferReadOp op) {
- return op.getSource();
-}
-
-static Value getMemRefOperand(vector::TransferWriteOp op) {
- return op.getSource();
-}
-
-/// Given the permutation map of the original
-/// `vector.transfer_read`/`vector.transfer_write` operations compute the
-/// permutation map to use after the subview is folded with it.
-static AffineMapAttr getPermutationMapAttr(MLIRContext *context,
- memref::SubViewOp subViewOp,
- AffineMap currPermutationMap) {
- llvm::SmallBitVector unusedDims = subViewOp.getDroppedDims();
- SmallVector<AffineExpr> exprs;
- int64_t sourceRank = subViewOp.getSourceType().getRank();
- for (auto dim : llvm::seq<int64_t>(0, sourceRank)) {
- if (unusedDims.test(dim))
- continue;
- exprs.push_back(getAffineDimExpr(dim, context));
- }
- auto resultDimToSourceDimMap = AffineMap::get(sourceRank, 0, exprs, context);
- return AffineMapAttr::get(
- currPermutationMap.compose(resultDimToSourceDimMap));
-}
-
-//===----------------------------------------------------------------------===//
-// Patterns
-//===----------------------------------------------------------------------===//
-
-namespace {
-/// Merges subview operation with load/transferRead operation.
-template <typename OpTy>
-class LoadOpOfSubViewFolder final : public OpRewritePattern<OpTy> {
-public:
- using OpRewritePattern<OpTy>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(OpTy loadOp,
- PatternRewriter &rewriter) const override;
-
-private:
- void replaceOp(OpTy loadOp, memref::SubViewOp subViewOp,
- ArrayRef<Value> sourceIndices,
- PatternRewriter &rewriter) const;
-};
-
-/// Merges subview operation with store/transferWriteOp operation.
-template <typename OpTy>
-class StoreOpOfSubViewFolder final : public OpRewritePattern<OpTy> {
-public:
- using OpRewritePattern<OpTy>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(OpTy storeOp,
- PatternRewriter &rewriter) const override;
-
-private:
- void replaceOp(OpTy storeOp, memref::SubViewOp subViewOp,
- ArrayRef<Value> sourceIndices,
- PatternRewriter &rewriter) const;
-};
-
-template <typename LoadOpTy>
-void LoadOpOfSubViewFolder<LoadOpTy>::replaceOp(
- LoadOpTy loadOp, memref::SubViewOp subViewOp, ArrayRef<Value> sourceIndices,
- PatternRewriter &rewriter) const {
- rewriter.replaceOpWithNewOp<LoadOpTy>(loadOp, subViewOp.getSource(),
- sourceIndices);
-}
-
-template <>
-void LoadOpOfSubViewFolder<vector::TransferReadOp>::replaceOp(
- vector::TransferReadOp transferReadOp, memref::SubViewOp subViewOp,
- ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
- // TODO: support 0-d corner case.
- if (transferReadOp.getTransferRank() == 0)
- return;
- rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
- transferReadOp, transferReadOp.getVectorType(), subViewOp.getSource(),
- sourceIndices,
- getPermutationMapAttr(rewriter.getContext(), subViewOp,
- transferReadOp.getPermutationMap()),
- transferReadOp.getPadding(),
- /*mask=*/Value(), transferReadOp.getInBoundsAttr());
-}
-
-template <typename StoreOpTy>
-void StoreOpOfSubViewFolder<StoreOpTy>::replaceOp(
- StoreOpTy storeOp, memref::SubViewOp subViewOp,
- ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
- rewriter.replaceOpWithNewOp<StoreOpTy>(storeOp, storeOp.getValue(),
- subViewOp.getSource(), sourceIndices);
-}
-
-template <>
-void StoreOpOfSubViewFolder<vector::TransferWriteOp>::replaceOp(
- vector::TransferWriteOp transferWriteOp, memref::SubViewOp subViewOp,
- ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
- // TODO: support 0-d corner case.
- if (transferWriteOp.getTransferRank() == 0)
- return;
- rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
- transferWriteOp, transferWriteOp.getVector(), subViewOp.getSource(),
- sourceIndices,
- getPermutationMapAttr(rewriter.getContext(), subViewOp,
- transferWriteOp.getPermutationMap()),
- transferWriteOp.getInBoundsAttr());
-}
-} // namespace
-
-template <typename OpTy>
-LogicalResult
-LoadOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy loadOp,
- PatternRewriter &rewriter) const {
- auto subViewOp =
- getMemRefOperand(loadOp).template getDefiningOp<memref::SubViewOp>();
- if (!subViewOp)
- return failure();
-
- SmallVector<Value, 4> sourceIndices;
- if (failed(resolveSourceIndices(loadOp.getLoc(), rewriter, subViewOp,
- loadOp.getIndices(), sourceIndices)))
- return failure();
-
- replaceOp(loadOp, subViewOp, sourceIndices, rewriter);
- return success();
-}
-
-template <typename OpTy>
-LogicalResult
-StoreOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy storeOp,
- PatternRewriter &rewriter) const {
- auto subViewOp =
- getMemRefOperand(storeOp).template getDefiningOp<memref::SubViewOp>();
- if (!subViewOp)
- return failure();
-
- SmallVector<Value, 4> sourceIndices;
- if (failed(resolveSourceIndices(storeOp.getLoc(), rewriter, subViewOp,
- storeOp.getIndices(), sourceIndices)))
- return failure();
-
- replaceOp(storeOp, subViewOp, sourceIndices, rewriter);
- return success();
-}
-
-void memref::populateFoldSubViewOpPatterns(RewritePatternSet &patterns) {
- patterns.add<LoadOpOfSubViewFolder<AffineLoadOp>,
- LoadOpOfSubViewFolder<memref::LoadOp>,
- LoadOpOfSubViewFolder<vector::TransferReadOp>,
- StoreOpOfSubViewFolder<AffineStoreOp>,
- StoreOpOfSubViewFolder<memref::StoreOp>,
- StoreOpOfSubViewFolder<vector::TransferWriteOp>>(
- patterns.getContext());
-}
-
-//===----------------------------------------------------------------------===//
-// Pass registration
-//===----------------------------------------------------------------------===//
-
-namespace {
-
-struct FoldSubViewOpsPass final
- : public FoldSubViewOpsBase<FoldSubViewOpsPass> {
- void runOnOperation() override;
-};
-
-} // namespace
-
-void FoldSubViewOpsPass::runOnOperation() {
- RewritePatternSet patterns(&getContext());
- memref::populateFoldSubViewOpPatterns(patterns);
- (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
-}
-
-std::unique_ptr<Pass> memref::createFoldSubViewOpsPass() {
- return std::make_unique<FoldSubViewOpsPass>();
-}
diff --git a/mlir/lib/Dialect/Utils/IndexingUtils.cpp b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
index 2d1bfc1e92a50..b9901b7af0ff1 100644
--- a/mlir/lib/Dialect/Utils/IndexingUtils.cpp
+++ b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
@@ -8,6 +8,8 @@
#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
int64_t mlir::linearize(ArrayRef<int64_t> offsets, ArrayRef<int64_t> basis) {
@@ -42,3 +44,26 @@ llvm::SmallVector<int64_t, 4> mlir::getI64SubArray(ArrayAttr arrayAttr,
res.push_back((*it).getValue().getSExtValue());
return res;
}
+
+mlir::AffineExpr mlir::getLinearAffineExpr(ArrayRef<int64_t> basis,
+ mlir::Builder &b) {
+ AffineExpr resultExpr = b.getAffineDimExpr(0);
+ resultExpr = resultExpr * basis[0];
+ for (unsigned i = 1; i < basis.size(); i++)
+ resultExpr = resultExpr + b.getAffineDimExpr(i) * basis[i];
+ return resultExpr;
+}
+
+llvm::SmallVector<mlir::AffineExpr, 4>
+mlir::getDelinearizedAffineExpr(mlir::ArrayRef<int64_t> strides, Builder &b) {
+ AffineExpr resultExpr = b.getAffineDimExpr(0);
+ int64_t rank = strides.size();
+ SmallVector<AffineExpr, 4> vectorOffsets(rank);
+ vectorOffsets[0] = resultExpr.floorDiv(strides[0]);
+ resultExpr = resultExpr % strides[0];
+ for (unsigned i = 1; i < rank; i++) {
+ vectorOffsets[i] = resultExpr.floorDiv(strides[i]);
+ resultExpr = resultExpr % strides[i];
+ }
+ return vectorOffsets;
+}
diff --git a/mlir/test/Dialect/MemRef/fold-subview-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
similarity index 64%
rename from mlir/test/Dialect/MemRef/fold-subview-ops.mlir
rename to mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
index 28138e93aec9e..18c2b3f403e98 100644
--- a/mlir/test/Dialect/MemRef/fold-subview-ops.mlir
+++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -fold-memref-subview-ops -split-input-file %s -o - | FileCheck %s
+// RUN: mlir-opt -fold-memref-alias-ops -split-input-file %s -o - | FileCheck %s
func.func @fold_static_stride_subview_with_load(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index) -> f32 {
%0 = memref.subview %arg0[%arg1, %arg2][4, 4][2, 3] : memref<12x32xf32> to memref<4x4xf32, offset:?, strides: [64, 3]>
@@ -272,3 +272,154 @@ func.func @fold_static_stride_subview_with_affine_load_store(%arg0 : memref<12x3
// CHECK-NEXT: return
return %1 : f32
}
+
+// -----
+
+// CHECK-DAG: #[[$MAP:.*]] = affine_map<(d0, d1) -> (d0 * 6 + d1)>
+// CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape
+// CHECK-SAME: (%[[ARG0:.*]]: memref<12x32xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index) -> f32 {
+func.func @fold_static_stride_subview_with_affine_load_store_expand_shape(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index) -> f32 {
+ %0 = memref.expand_shape %arg0 [[0, 1], [2]] : memref<12x32xf32> into memref<2x6x32xf32>
+ %1 = affine.load %0[%arg1, %arg2, %arg3] : memref<2x6x32xf32>
+ return %1 : f32
+}
+// CHECK: %[[INDEX:.*]] = affine.apply #[[$MAP]](%[[ARG1]], %[[ARG2]])
+// CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[INDEX]], %[[ARG3]]] : memref<12x32xf32>
+// CHECK-NEXT: return %[[RESULT]] : f32
+
+// -----
+
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0) -> (d0 floordiv 6)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0) -> (d0 mod 6)>
+// CHECK-LABEL: @fold_static_stride_subview_with_affine_load_store_collapse_shape
+// CHECK-SAME: (%[[ARG0:.*]]: memref<2x6x32xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
+func.func @fold_static_stride_subview_with_affine_load_store_collapse_shape(%arg0 : memref<2x6x32xf32>, %arg1 : index, %arg2 : index) -> f32 {
+ %0 = memref.collapse_shape %arg0 [[0, 1], [2]] : memref<2x6x32xf32> into memref<12x32xf32>
+ %1 = affine.load %0[%arg1, %arg2] : memref<12x32xf32>
+ return %1 : f32
+}
+// CHECK-NEXT: %[[MODIFIED_INDEX0:.*]] = affine.apply #[[$MAP0]](%[[ARG1]])
+// CHECK-NEXT: %[[MODIFIED_INDEX1:.*]] = affine.apply #[[$MAP1]](%[[ARG1]])
+// CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[MODIFIED_INDEX0]], %[[MODIFIED_INDEX1]], %[[ARG2]]] : memref<2x6x32xf32>
+// CHECK-NEXT: return %[[RESULT]] : f32
+
+// -----
+
+// CHECK-DAG: #[[$MAP:.*]] = affine_map<(d0, d1, d2) -> (d0 * 6 + d1 * 3 + d2)>
+// CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape_3d
+// CHECK-SAME: (%[[ARG0:.*]]: memref<12x32xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index) -> f32 {
+func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_3d(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4: index) -> f32 {
+ %0 = memref.expand_shape %arg0 [[0, 1, 2], [3]] : memref<12x32xf32> into memref<2x2x3x32xf32>
+ %1 = affine.load %0[%arg1, %arg2, %arg3, %arg4] : memref<2x2x3x32xf32>
+ return %1 : f32
+}
+// CHECK: %[[INDEX:.*]] = affine.apply #[[$MAP]](%[[ARG1]], %[[ARG2]], %[[ARG3]])
+// CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[INDEX]], %[[ARG4]]] : memref<12x32xf32>
+// CHECK-NEXT: return %[[RESULT]] : f32
+
+// -----
+
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0 * 1024 + d1)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
+// CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape
+// CHECK-SAME: (%[[ARG0:.*]]: memref<1024x1024xf32>, %[[ARG1:.*]]: memref<1xf32>, %[[ARG2:.*]]: index)
+func.func @fold_static_stride_subview_with_affine_load_store_expand_shape(%arg0: memref<1024x1024xf32>, %arg1: memref<1xf32>, %arg2: index) -> f32 {
+ %0 = memref.expand_shape %arg0 [[0, 1], [2, 3]] : memref<1024x1024xf32> into memref<1x1024x1024x1xf32>
+ affine.for %arg3 = 0 to 1 {
+ affine.for %arg4 = 0 to 1024 {
+ affine.for %arg5 = 0 to 1020 {
+ affine.for %arg6 = 0 to 1 {
+ %1 = affine.load %0[%arg3, %arg4, %arg5, %arg6] : memref<1x1024x1024x1xf32>
+ affine.store %1, %arg1[%arg2] : memref<1xf32>
+ }
+ }
+ }
+ }
+ %2 = affine.load %arg1[%arg2] : memref<1xf32>
+ return %2 : f32
+}
+// CHECK-NEXT: affine.for %[[ARG3:.*]] = 0 to 1 {
+// CHECK-NEXT: affine.for %[[ARG4:.*]] = 0 to 1024 {
+// CHECK-NEXT: affine.for %[[ARG5:.*]] = 0 to 1020 {
+// CHECK-NEXT: affine.for %[[ARG6:.*]] = 0 to 1 {
+// CHECK-NEXT: %[[IDX1:.*]] = affine.apply #[[$MAP0]](%[[ARG3]], %[[ARG4]])
+// CHECK-NEXT: %[[IDX2:.*]] = affine.apply #[[$MAP1]](%[[ARG5]], %[[ARG6]])
+// CHECK-NEXT: affine.load %[[ARG0]][%[[IDX1]], %[[IDX2]]] : memref<1024x1024xf32>
+
+// -----
+
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d1 + d0)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0 * 1024 + d1)>
+// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
+// CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape_when_access_index_is_an_expression
+// CHECK-SAME: (%[[ARG0:.*]]: memref<1024x1024xf32>, %[[ARG1:.*]]: memref<1xf32>, %[[ARG2:.*]]: index)
+func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_when_access_index_is_an_expression(%arg0: memref<1024x1024xf32>, %arg1: memref<1xf32>, %arg2: index) -> f32 {
+ %0 = memref.expand_shape %arg0 [[0, 1], [2, 3]] : memref<1024x1024xf32> into memref<1x1024x1024x1xf32>
+ affine.for %arg3 = 0 to 1 {
+ affine.for %arg4 = 0 to 1024 {
+ affine.for %arg5 = 0 to 1020 {
+ affine.for %arg6 = 0 to 1 {
+ %1 = affine.load %0[%arg3, %arg4 + %arg3, %arg5, %arg6] : memref<1x1024x1024x1xf32>
+ affine.store %1, %arg1[%arg2] : memref<1xf32>
+ }
+ }
+ }
+ }
+ %2 = affine.load %arg1[%arg2] : memref<1xf32>
+ return %2 : f32
+}
+// CHECK-NEXT: affine.for %[[ARG3:.*]] = 0 to 1 {
+// CHECK-NEXT: affine.for %[[ARG4:.*]] = 0 to 1024 {
+// CHECK-NEXT: affine.for %[[ARG5:.*]] = 0 to 1020 {
+// CHECK-NEXT: affine.for %[[ARG6:.*]] = 0 to 1 {
+// CHECK-NEXT: %[[TMP1:.*]] = affine.apply #[[$MAP0]](%[[ARG3]], %[[ARG4]], %[[ARG5]], %[[ARG6]])
+// CHECK-NEXT: %[[TMP2:.*]] = affine.apply #[[$MAP1]](%[[ARG3]], %[[TMP1]])
+// CHECK-NEXT:    %[[TMP3:.*]] = affine.apply #[[$MAP2]](%[[ARG5]], %[[ARG6]])
+// CHECK-NEXT: affine.load %[[ARG0]][%[[TMP2]], %[[TMP3]]] : memref<1024x1024xf32>
+
+// -----
+
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0 * 1024 + d1)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
+// CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape_with_constant_access_index
+// CHECK-SAME: (%[[ARG0:.*]]: memref<1024x1024xf32>, %[[ARG1:.*]]: memref<1xf32>, %[[ARG2:.*]]: index)
+func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_with_constant_access_index(%arg0: memref<1024x1024xf32>, %arg1: memref<1xf32>, %arg2: index) -> f32 {
+ %0 = memref.expand_shape %arg0 [[0, 1], [2, 3]] : memref<1024x1024xf32> into memref<1x1024x1024x1xf32>
+ %cst = arith.constant 0 : index
+ affine.for %arg3 = 0 to 1 {
+ affine.for %arg4 = 0 to 1024 {
+ affine.for %arg5 = 0 to 1020 {
+ affine.for %arg6 = 0 to 1 {
+ %1 = memref.load %0[%arg3, %cst, %arg5, %arg6] : memref<1x1024x1024x1xf32>
+ memref.store %1, %arg1[%arg2] : memref<1xf32>
+ }
+ }
+ }
+ }
+ %2 = memref.load %arg1[%arg2] : memref<1xf32>
+ return %2 : f32
+}
+// CHECK-NEXT: %[[ZERO:.*]] = arith.constant 0 : index
+// CHECK-NEXT: affine.for %[[ARG3:.*]] = 0 to 1 {
+// CHECK-NEXT: affine.for %[[ARG4:.*]] = 0 to 1024 {
+// CHECK-NEXT: affine.for %[[ARG5:.*]] = 0 to 1020 {
+// CHECK-NEXT: affine.for %[[ARG6:.*]] = 0 to 1 {
+// CHECK-NEXT: %[[TMP1:.*]] = affine.apply #[[$MAP0]](%[[ARG3]], %[[ZERO]])
+// CHECK-NEXT: %[[TMP2:.*]] = affine.apply #[[$MAP1]](%[[ARG5]], %[[ARG6]])
+// CHECK-NEXT: memref.load %[[ARG0]][%[[TMP1]], %[[TMP2]]] : memref<1024x1024xf32>
+
+// -----
+
+// CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_collapse_shape_with_0d_result
+// CHECK-SAME: (%[[ARG0:.*]]: memref<1xf32>, %[[ARG1:.*]]: memref<1xf32>)
+func.func @fold_static_stride_subview_with_affine_load_store_collapse_shape_with_0d_result(%arg0: memref<1xf32>, %arg1: memref<1xf32>) -> memref<1xf32> {
+ %0 = memref.collapse_shape %arg0 [] : memref<1xf32> into memref<f32>
+ affine.for %arg2 = 0 to 3 {
+ %1 = affine.load %0[] : memref<f32>
+ affine.store %1, %arg1[0] : memref<1xf32>
+ }
+ return %arg1 : memref<1xf32>
+}
+// CHECK-NEXT: %[[ZERO:.*]] = arith.constant 0 : index
+// CHECK-NEXT: affine.for %{{.*}} = 0 to 3 {
+// CHECK-NEXT: affine.load %[[ARG0]][%[[ZERO]]] : memref<1xf32>
diff --git a/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp b/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp
index d942ef96f4c4f..8dbbd17ba0a08 100644
--- a/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp
+++ b/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp
@@ -46,7 +46,7 @@ static LogicalResult runMLIRPasses(ModuleOp module) {
applyPassManagerCLOptions(passManager);
passManager.addPass(createGpuKernelOutliningPass());
- passManager.addPass(memref::createFoldSubViewOpsPass());
+ passManager.addPass(memref::createFoldMemRefAliasOpsPass());
passManager.addPass(createConvertGPUToSPIRVPass(/*mapMemorySpace=*/true));
OpPassManager &modulePM = passManager.nest<spirv::ModuleOp>();