[Mlir-commits] [mlir] 7cde516 - [mlir][vector] NFC, move some vector patterns in a separate file
Thomas Raoux
llvmlistbot at llvm.org
Fri Nov 19 10:40:11 PST 2021
Author: Thomas Raoux
Date: 2021-11-19T10:39:29-08:00
New Revision: 7cde5165131f1268a8506066275ef7938c58d156
URL: https://github.com/llvm/llvm-project/commit/7cde5165131f1268a8506066275ef7938c58d156
DIFF: https://github.com/llvm/llvm-project/commit/7cde5165131f1268a8506066275ef7938c58d156.diff
LOG: [mlir][vector] NFC, move some vector patterns in a separate file
Move patterns related to dropping lead unit dim into their own file.
Differential Revision: https://reviews.llvm.org/D114265
Added:
mlir/lib/Dialect/Vector/VectorDropLeadUnitDim.cpp
Modified:
mlir/lib/Dialect/Vector/CMakeLists.txt
mlir/lib/Dialect/Vector/VectorTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/CMakeLists.txt b/mlir/lib/Dialect/Vector/CMakeLists.txt
index abd961626e59..8f01eda3de4f 100644
--- a/mlir/lib/Dialect/Vector/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/CMakeLists.txt
@@ -1,4 +1,5 @@
add_mlir_dialect_library(MLIRVector
+ VectorDropLeadUnitDim.cpp
VectorInsertExtractStridedSliceRewritePatterns.cpp
VectorMultiDimReductionTransforms.cpp
VectorOps.cpp
diff --git a/mlir/lib/Dialect/Vector/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/VectorDropLeadUnitDim.cpp
new file mode 100644
index 000000000000..f00a4f808c69
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/VectorDropLeadUnitDim.cpp
@@ -0,0 +1,259 @@
+//===- VectorDropLeadUnitDim.cpp - Conversion within the Vector dialect ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Vector/VectorRewritePatterns.h"
+#include "mlir/Dialect/Vector/VectorUtils.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/TypeUtilities.h"
+
+#define DEBUG_TYPE "vector-drop-unit-dim"
+
+using namespace mlir;
+using namespace mlir::vector;
+
+// Trims leading one dimensions from `oldType` and returns the result type.
+// Returns `vector<1xT>` if `oldType` only has one element.
+static VectorType trimLeadingOneDims(VectorType oldType) {
+ ArrayRef<int64_t> oldShape = oldType.getShape();
+ ArrayRef<int64_t> newShape =
+ oldShape.drop_while([](int64_t dim) { return dim == 1; });
+ // Make sure we have at least 1 dimension per vector type requirements.
+ if (newShape.empty())
+ newShape = oldShape.take_back();
+ return VectorType::get(newShape, oldType.getElementType());
+}
+
+/// Returns a SmallVector of size `rank` containing all zeros.
+static SmallVector<int64_t> splatZero(int64_t rank) {
+ return SmallVector<int64_t>(rank, 0);
+}
+namespace {
+
+// Casts away leading one dimensions in vector.extract_strided_slice's vector
+// input by inserting vector.extract and vector.broadcast.
+struct CastAwayExtractStridedSliceLeadingOneDim
+ : public OpRewritePattern<vector::ExtractStridedSliceOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
+ PatternRewriter &rewriter) const override {
+ // vector.extract_strided_slice requires the input and output vector to have
+ // the same rank. Here we drop leading one dimensions from the input vector
+ // type to make sure we don't cause mismatch.
+ VectorType oldSrcType = extractOp.getVectorType();
+ VectorType newSrcType = trimLeadingOneDims(oldSrcType);
+
+ if (newSrcType.getRank() == oldSrcType.getRank())
+ return failure();
+
+ int64_t dropCount = oldSrcType.getRank() - newSrcType.getRank();
+
+ VectorType oldDstType = extractOp.getType();
+ VectorType newDstType =
+ VectorType::get(oldDstType.getShape().drop_front(dropCount),
+ oldDstType.getElementType());
+
+ Location loc = extractOp.getLoc();
+
+ Value newSrcVector = rewriter.create<vector::ExtractOp>(
+ loc, extractOp.vector(), splatZero(dropCount));
+
+ // The offsets/sizes/strides attributes can have fewer elements than the
+ // input vector's rank: they apply to the leading dimensions.
+ auto newOffsets = rewriter.getArrayAttr(
+ extractOp.offsets().getValue().drop_front(dropCount));
+ auto newSizes = rewriter.getArrayAttr(
+ extractOp.sizes().getValue().drop_front(dropCount));
+ auto newStrides = rewriter.getArrayAttr(
+ extractOp.strides().getValue().drop_front(dropCount));
+
+ auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>(
+ loc, newDstType, newSrcVector, newOffsets, newSizes, newStrides);
+
+ rewriter.replaceOpWithNewOp<vector::BroadcastOp>(extractOp, oldDstType,
+ newExtractOp);
+
+ return success();
+ }
+};
+
+// Casts away leading one dimensions in vector.insert_strided_slice's vector
+// inputs by inserting vector.extract and vector.broadcast.
+struct CastAwayInsertStridedSliceLeadingOneDim
+ : public OpRewritePattern<vector::InsertStridedSliceOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::InsertStridedSliceOp insertOp,
+ PatternRewriter &rewriter) const override {
+ VectorType oldSrcType = insertOp.getSourceVectorType();
+ VectorType newSrcType = trimLeadingOneDims(oldSrcType);
+ VectorType oldDstType = insertOp.getDestVectorType();
+ VectorType newDstType = trimLeadingOneDims(oldDstType);
+
+ int64_t srcDropCount = oldSrcType.getRank() - newSrcType.getRank();
+ int64_t dstDropCount = oldDstType.getRank() - newDstType.getRank();
+ if (srcDropCount == 0 && dstDropCount == 0)
+ return failure();
+
+ // Trim leading one dimensions from both operands.
+ Location loc = insertOp.getLoc();
+
+ Value newSrcVector = rewriter.create<vector::ExtractOp>(
+ loc, insertOp.source(), splatZero(srcDropCount));
+ Value newDstVector = rewriter.create<vector::ExtractOp>(
+ loc, insertOp.dest(), splatZero(dstDropCount));
+
+ auto newOffsets = rewriter.getArrayAttr(
+ insertOp.offsets().getValue().take_back(newDstType.getRank()));
+ auto newStrides = rewriter.getArrayAttr(
+ insertOp.strides().getValue().take_back(newSrcType.getRank()));
+
+ auto newInsertOp = rewriter.create<vector::InsertStridedSliceOp>(
+ loc, newDstType, newSrcVector, newDstVector, newOffsets, newStrides);
+
+ rewriter.replaceOpWithNewOp<vector::BroadcastOp>(insertOp, oldDstType,
+ newInsertOp);
+
+ return success();
+ }
+};
+
+// Turns vector.transfer_read on vector with leading 1 dimensions into
+// vector.transfer_read on vector without leading 1 dimensions followed by
+// vector.broadcast back to the original vector type.
+struct CastAwayTransferReadLeadingOneDim
+ : public OpRewritePattern<vector::TransferReadOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::TransferReadOp read,
+ PatternRewriter &rewriter) const override {
+ if (read.mask())
+ return failure();
+
+ auto shapedType = read.source().getType().cast<ShapedType>();
+ if (shapedType.getElementType() != read.getVectorType().getElementType())
+ return failure();
+
+ VectorType oldType = read.getVectorType();
+ VectorType newType = trimLeadingOneDims(oldType);
+
+ if (newType == oldType)
+ return failure();
+
+ AffineMap oldMap = read.permutation_map();
+ ArrayRef<AffineExpr> newResults =
+ oldMap.getResults().take_back(newType.getRank());
+ AffineMap newMap =
+ AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), newResults,
+ rewriter.getContext());
+
+ ArrayAttr inBounds;
+ if (read.in_bounds())
+ inBounds = rewriter.getArrayAttr(
+ read.in_boundsAttr().getValue().take_back(newType.getRank()));
+
+ auto newRead = rewriter.create<vector::TransferReadOp>(
+ read.getLoc(), newType, read.source(), read.indices(), newMap,
+ read.padding(), inBounds);
+ rewriter.replaceOpWithNewOp<vector::BroadcastOp>(read, oldType, newRead);
+
+ return success();
+ }
+};
+
+// Turns vector.transfer_write on vector with leading 1 dimensions into
+// vector.extract followed by vector.transfer_write on vector without leading
+// 1 dimensions. Masked writes and element type mismatches are not rewritten.
+struct CastAwayTransferWriteLeadingOneDim
+ : public OpRewritePattern<vector::TransferWriteOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::TransferWriteOp write,
+ PatternRewriter &rewriter) const override {
+ if (write.mask())
+ return failure();
+ // The source is expected to be a shaped type; cast<> asserts otherwise.
+ auto shapedType = write.source().getType().cast<ShapedType>();
+ if (shapedType.getElementType() != write.getVectorType().getElementType())
+ return failure();
+ // Nothing to do when the vector type has no leading unit dimension.
+ VectorType oldType = write.getVectorType();
+ VectorType newType = trimLeadingOneDims(oldType);
+ if (newType == oldType)
+ return failure();
+ int64_t dropDim = oldType.getRank() - newType.getRank();
+ // Keep only the permutation map results for the remaining dimensions.
+ AffineMap oldMap = write.permutation_map();
+ ArrayRef<AffineExpr> newResults =
+ oldMap.getResults().take_back(newType.getRank());
+ AffineMap newMap =
+ AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), newResults,
+ rewriter.getContext());
+ // Likewise keep only the trailing in_bounds entries, when present.
+ ArrayAttr inBounds;
+ if (write.in_bounds())
+ inBounds = rewriter.getArrayAttr(
+ write.in_boundsAttr().getValue().take_back(newType.getRank()));
+ // Extract at position 0 along each dropped leading unit dimension.
+ auto newVector = rewriter.create<vector::ExtractOp>(
+ write.getLoc(), write.vector(), splatZero(dropDim));
+ rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
+ write, newVector, write.source(), write.indices(), newMap, inBounds);
+
+ return success();
+ }
+};
+
+class CastAwayElementwiseLeadingOneDim : public RewritePattern {
+public:
+ CastAwayElementwiseLeadingOneDim(MLIRContext *context)
+ : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
+
+ LogicalResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const override {
+ if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1)
+ return failure();
+ auto vecType = op->getResultTypes()[0].dyn_cast<VectorType>();
+ if (!vecType)
+ return failure();
+ VectorType newVecType = trimLeadingOneDims(vecType);
+ if (newVecType == vecType)
+ return failure();
+ int64_t dropDim = vecType.getRank() - newVecType.getRank();
+ SmallVector<Value, 4> newOperands;
+ for (Value operand : op->getOperands()) {
+ if (auto opVecType = operand.getType().dyn_cast<VectorType>()) {
+ newOperands.push_back(rewriter.create<vector::ExtractOp>(
+ op->getLoc(), operand, splatZero(dropDim)));
+ } else {
+ newOperands.push_back(operand);
+ }
+ }
+ OperationState state(op->getLoc(), op->getName());
+ state.addAttributes(op->getAttrs());
+ state.addOperands(newOperands);
+ state.addTypes(newVecType);
+ Operation *newOp = rewriter.createOperation(state);
+ rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, vecType,
+ newOp->getResult(0));
+ return success();
+ }
+};
+
+} // namespace
+
+void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(
+ RewritePatternSet &patterns) {
+ patterns.add<CastAwayExtractStridedSliceLeadingOneDim,
+ CastAwayInsertStridedSliceLeadingOneDim,
+ CastAwayTransferReadLeadingOneDim,
+ CastAwayTransferWriteLeadingOneDim,
+ CastAwayElementwiseLeadingOneDim>(patterns.getContext());
+ populateShapeCastFoldingPatterns(patterns);
+}
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 8a77f171c61b..37f3c31e6a48 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -2931,234 +2931,6 @@ struct TransferWriteToVectorStoreLowering
llvm::Optional<unsigned> maxTransferRank;
};
-// Trims leading one dimensions from `oldType` and returns the result type.
-// Returns `vector<1xT>` if `oldType` only has one element.
-static VectorType trimLeadingOneDims(VectorType oldType) {
- ArrayRef<int64_t> oldShape = oldType.getShape();
- ArrayRef<int64_t> newShape =
- oldShape.drop_while([](int64_t dim) { return dim == 1; });
- // Make sure we have at least 1 dimension per vector type requirements.
- if (newShape.empty())
- newShape = oldShape.take_back();
- return VectorType::get(newShape, oldType.getElementType());
-}
-
-/// Return a smallVector of size `rank` containing all zeros.
-static SmallVector<int64_t> splatZero(int64_t rank) {
- return SmallVector<int64_t>(rank, 0);
-}
-
-// Casts away leading one dimensions in vector.extract_strided_slice's vector
-// input by inserting vector.shape_cast.
-struct CastAwayExtractStridedSliceLeadingOneDim
- : public OpRewritePattern<vector::ExtractStridedSliceOp> {
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
- PatternRewriter &rewriter) const override {
- // vector.extract_strided_slice requires the input and output vector to have
- // the same rank. Here we drop leading one dimensions from the input vector
- // type to make sure we don't cause mismatch.
- VectorType oldSrcType = extractOp.getVectorType();
- VectorType newSrcType = trimLeadingOneDims(oldSrcType);
-
- if (newSrcType.getRank() == oldSrcType.getRank())
- return failure();
-
- int64_t dropCount = oldSrcType.getRank() - newSrcType.getRank();
-
- VectorType oldDstType = extractOp.getType();
- VectorType newDstType =
- VectorType::get(oldDstType.getShape().drop_front(dropCount),
- oldDstType.getElementType());
-
- Location loc = extractOp.getLoc();
-
- Value newSrcVector = rewriter.create<vector::ExtractOp>(
- loc, extractOp.vector(), splatZero(dropCount));
-
- // The offsets/sizes/strides attribute can have a less number of elements
- // than the input vector's rank: it is meant for the leading dimensions.
- auto newOffsets = rewriter.getArrayAttr(
- extractOp.offsets().getValue().drop_front(dropCount));
- auto newSizes = rewriter.getArrayAttr(
- extractOp.sizes().getValue().drop_front(dropCount));
- auto newStrides = rewriter.getArrayAttr(
- extractOp.strides().getValue().drop_front(dropCount));
-
- auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>(
- loc, newDstType, newSrcVector, newOffsets, newSizes, newStrides);
-
- rewriter.replaceOpWithNewOp<vector::BroadcastOp>(extractOp, oldDstType,
- newExtractOp);
-
- return success();
- }
-};
-
-// Casts away leading one dimensions in vector.extract_strided_slice's vector
-// inputs by inserting vector.shape_cast.
-struct CastAwayInsertStridedSliceLeadingOneDim
- : public OpRewritePattern<vector::InsertStridedSliceOp> {
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult matchAndRewrite(vector::InsertStridedSliceOp insertOp,
- PatternRewriter &rewriter) const override {
- VectorType oldSrcType = insertOp.getSourceVectorType();
- VectorType newSrcType = trimLeadingOneDims(oldSrcType);
- VectorType oldDstType = insertOp.getDestVectorType();
- VectorType newDstType = trimLeadingOneDims(oldDstType);
-
- int64_t srcDropCount = oldSrcType.getRank() - newSrcType.getRank();
- int64_t dstDropCount = oldDstType.getRank() - newDstType.getRank();
- if (srcDropCount == 0 && dstDropCount == 0)
- return failure();
-
- // Trim leading one dimensions from both operands.
- Location loc = insertOp.getLoc();
-
- Value newSrcVector = rewriter.create<vector::ExtractOp>(
- loc, insertOp.source(), splatZero(srcDropCount));
- Value newDstVector = rewriter.create<vector::ExtractOp>(
- loc, insertOp.dest(), splatZero(dstDropCount));
-
- auto newOffsets = rewriter.getArrayAttr(
- insertOp.offsets().getValue().take_back(newDstType.getRank()));
- auto newStrides = rewriter.getArrayAttr(
- insertOp.strides().getValue().take_back(newSrcType.getRank()));
-
- auto newInsertOp = rewriter.create<vector::InsertStridedSliceOp>(
- loc, newDstType, newSrcVector, newDstVector, newOffsets, newStrides);
-
- rewriter.replaceOpWithNewOp<vector::BroadcastOp>(insertOp, oldDstType,
- newInsertOp);
-
- return success();
- }
-};
-
-// Turns vector.transfer_read on vector with leading 1 dimensions into
-// vector.shape_cast followed by vector.transfer_read on vector without leading
-// 1 dimensions.
-struct CastAwayTransferReadLeadingOneDim
- : public OpRewritePattern<vector::TransferReadOp> {
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult matchAndRewrite(vector::TransferReadOp read,
- PatternRewriter &rewriter) const override {
- if (read.mask())
- return failure();
-
- auto shapedType = read.source().getType().cast<ShapedType>();
- if (shapedType.getElementType() != read.getVectorType().getElementType())
- return failure();
-
- VectorType oldType = read.getVectorType();
- VectorType newType = trimLeadingOneDims(oldType);
-
- if (newType == oldType)
- return failure();
-
- AffineMap oldMap = read.permutation_map();
- ArrayRef<AffineExpr> newResults =
- oldMap.getResults().take_back(newType.getRank());
- AffineMap newMap =
- AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), newResults,
- rewriter.getContext());
-
- ArrayAttr inBounds;
- if (read.in_bounds())
- inBounds = rewriter.getArrayAttr(
- read.in_boundsAttr().getValue().take_back(newType.getRank()));
-
- auto newRead = rewriter.create<vector::TransferReadOp>(
- read.getLoc(), newType, read.source(), read.indices(), newMap,
- read.padding(), inBounds);
- rewriter.replaceOpWithNewOp<vector::BroadcastOp>(read, oldType, newRead);
-
- return success();
- }
-};
-
-// Turns vector.transfer_write on vector with leading 1 dimensions into
-// vector.shape_cast followed by vector.transfer_write on vector without leading
-// 1 dimensions.
-struct CastAwayTransferWriteLeadingOneDim
- : public OpRewritePattern<vector::TransferWriteOp> {
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult matchAndRewrite(vector::TransferWriteOp write,
- PatternRewriter &rewriter) const override {
- if (write.mask())
- return failure();
-
- auto shapedType = write.source().getType().dyn_cast<ShapedType>();
- if (shapedType.getElementType() != write.getVectorType().getElementType())
- return failure();
-
- VectorType oldType = write.getVectorType();
- VectorType newType = trimLeadingOneDims(oldType);
- if (newType == oldType)
- return failure();
- int64_t dropDim = oldType.getRank() - newType.getRank();
-
- AffineMap oldMap = write.permutation_map();
- ArrayRef<AffineExpr> newResults =
- oldMap.getResults().take_back(newType.getRank());
- AffineMap newMap =
- AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), newResults,
- rewriter.getContext());
-
- ArrayAttr inBounds;
- if (write.in_bounds())
- inBounds = rewriter.getArrayAttr(
- write.in_boundsAttr().getValue().take_back(newType.getRank()));
-
- auto newVector = rewriter.create<vector::ExtractOp>(
- write.getLoc(), write.vector(), splatZero(dropDim));
- rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
- write, newVector, write.source(), write.indices(), newMap, inBounds);
-
- return success();
- }
-};
-
-class CastAwayElementwiseLeadingOneDim : public RewritePattern {
-public:
- CastAwayElementwiseLeadingOneDim(MLIRContext *context)
- : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
-
- LogicalResult matchAndRewrite(Operation *op,
- PatternRewriter &rewriter) const override {
- if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1)
- return failure();
- auto vecType = op->getResultTypes()[0].dyn_cast<VectorType>();
- if (!vecType)
- return failure();
- VectorType newVecType = trimLeadingOneDims(vecType);
- if (newVecType == vecType)
- return failure();
- int64_t dropDim = vecType.getRank() - newVecType.getRank();
- SmallVector<Value, 4> newOperands;
- for (Value operand : op->getOperands()) {
- if (auto opVecType = operand.getType().dyn_cast<VectorType>()) {
- newOperands.push_back(rewriter.create<vector::ExtractOp>(
- op->getLoc(), operand, splatZero(dropDim)));
- } else {
- newOperands.push_back(operand);
- }
- }
- OperationState state(op->getLoc(), op->getName());
- state.addAttributes(op->getAttrs());
- state.addOperands(newOperands);
- state.addTypes(newVecType);
- Operation *newOp = rewriter.createOperation(state);
- rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, vecType,
- newOp->getResult(0));
- return success();
- }
-};
-
// Returns the values in `arrayAttr` as an integer vector.
static SmallVector<int64_t, 4> getIntValueVector(ArrayAttr arrayAttr) {
return llvm::to_vector<4>(
@@ -3638,16 +3410,6 @@ void mlir::vector::populateShapeCastFoldingPatterns(
patterns.add<ShapeCastOpFolder>(patterns.getContext());
}
-void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(
- RewritePatternSet &patterns) {
- patterns.add<CastAwayExtractStridedSliceLeadingOneDim,
- CastAwayInsertStridedSliceLeadingOneDim,
- CastAwayTransferReadLeadingOneDim,
- CastAwayTransferWriteLeadingOneDim,
- CastAwayElementwiseLeadingOneDim>(patterns.getContext());
- populateShapeCastFoldingPatterns(patterns);
-}
-
void mlir::vector::populateBubbleVectorBitCastOpPatterns(
RewritePatternSet &patterns) {
patterns.add<BubbleDownVectorBitCastForExtract,
More information about the Mlir-commits
mailing list